You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/02/29 02:25:29 UTC
[13/14] spark git commit: [SPARK-13529][BUILD] Move network/* modules
into common/network-*
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
new file mode 100644
index 0000000..f0e2004
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicLong;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.server.MessageHandler;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+/**
+ * Handler that processes server responses, in response to requests issued from a
+ * {@link TransportClient}. It works by tracking the list of outstanding requests (and their
+ * callbacks).
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
+  private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
+
+  private final Channel channel;
+
+  /** Callbacks for chunk fetch requests that are still awaiting a response. */
+  private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
+
+  /** Callbacks for RPCs that are still awaiting a response, keyed by request id. */
+  private final Map<Long, RpcResponseCallback> outstandingRpcs;
+
+  /** Stream callbacks in the order they were added; each stream response consumes the head. */
+  private final Queue<StreamCallback> streamCallbacks;
+  /** Set once a stream interceptor has been installed; cleared by {@link #deactivateStream()}. */
+  private volatile boolean streamActive;
+
+  /** Records the time (in system nanoseconds) that the last request was sent. */
+  private final AtomicLong timeOfLastRequestNs;
+
+  public TransportResponseHandler(Channel channel) {
+    this.channel = channel;
+    this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+    this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+    this.streamCallbacks = new ConcurrentLinkedQueue<StreamCallback>();
+    this.timeOfLastRequestNs = new AtomicLong(0);
+  }
+
+  /** Registers the callback for an outstanding chunk fetch and records the request time. */
+  public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+    updateTimeOfLastRequest();
+    outstandingFetches.put(streamChunkId, callback);
+  }
+
+  /** Stops tracking a chunk fetch; its callback will no longer be notified by this handler. */
+  public void removeFetchRequest(StreamChunkId streamChunkId) {
+    outstandingFetches.remove(streamChunkId);
+  }
+
+  /** Registers the callback for an outstanding RPC and records the request time. */
+  public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+    updateTimeOfLastRequest();
+    outstandingRpcs.put(requestId, callback);
+  }
+
+  /** Stops tracking an RPC; its callback will no longer be notified by this handler. */
+  public void removeRpcRequest(long requestId) {
+    outstandingRpcs.remove(requestId);
+  }
+
+  /** Queues a callback for the next stream response and records the request time. */
+  public void addStreamCallback(StreamCallback callback) {
+    updateTimeOfLastRequest();
+    streamCallbacks.offer(callback);
+  }
+
+  @VisibleForTesting
+  public void deactivateStream() {
+    streamActive = false;
+  }
+
+  /**
+   * Fire the failure callback for all outstanding fetch and RPC requests. This is called when we
+   * have an uncaught exception or pre-mature connection termination.
+   *
+   * An exception thrown by one callback is logged and does not prevent the remaining callbacks
+   * from being notified, nor the tracking maps from being cleared. Queued stream callbacks are
+   * not notified here.
+   */
+  private void failOutstandingRequests(Throwable cause) {
+    for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+      try {
+        entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+      } catch (Exception e) {
+        logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
+      }
+    }
+    for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+      try {
+        entry.getValue().onFailure(cause);
+      } catch (Exception e) {
+        logger.warn("RpcResponseCallback.onFailure throws exception", e);
+      }
+    }
+
+    // It's OK if new fetches appear, as they will fail immediately.
+    outstandingFetches.clear();
+    outstandingRpcs.clear();
+  }
+
+  @Override
+  public void channelActive() {
+  }
+
+  @Override
+  public void channelInactive() {
+    if (numOutstandingRequests() > 0) {
+      String remoteAddress = NettyUtils.getRemoteAddress(channel);
+      logger.error("Still have {} requests outstanding when connection from {} is closed",
+        numOutstandingRequests(), remoteAddress);
+      failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
+    }
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause) {
+    if (numOutstandingRequests() > 0) {
+      String remoteAddress = NettyUtils.getRemoteAddress(channel);
+      logger.error("Still have {} requests outstanding when connection from {} is closed",
+        numOutstandingRequests(), remoteAddress);
+      failOutstandingRequests(cause);
+    }
+  }
+
+  /**
+   * Dispatches a server response to the callback registered for it. Responses without a
+   * registered callback are logged and dropped. The bodies of ChunkFetchSuccess and RpcResponse
+   * messages are released here whether or not a callback handled them.
+   */
+  @Override
+  public void handle(ResponseMessage message) throws Exception {
+    String remoteAddress = NettyUtils.getRemoteAddress(channel);
+    if (message instanceof ChunkFetchSuccess) {
+      ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+      // remove() both looks up and claims the callback, so it is notified at most once here.
+      ChunkReceivedCallback listener = outstandingFetches.remove(resp.streamChunkId);
+      if (listener == null) {
+        logger.warn("Ignoring response for block {} from {} since it is not outstanding",
+          resp.streamChunkId, remoteAddress);
+        resp.body().release();
+      } else {
+        try {
+          listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
+        } finally {
+          // Release even if the callback throws, so the buffer is not leaked.
+          resp.body().release();
+        }
+      }
+    } else if (message instanceof ChunkFetchFailure) {
+      ChunkFetchFailure resp = (ChunkFetchFailure) message;
+      ChunkReceivedCallback listener = outstandingFetches.remove(resp.streamChunkId);
+      if (listener == null) {
+        logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
+          resp.streamChunkId, remoteAddress, resp.errorString);
+      } else {
+        listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
+          "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
+      }
+    } else if (message instanceof RpcResponse) {
+      RpcResponse resp = (RpcResponse) message;
+      RpcResponseCallback listener = outstandingRpcs.remove(resp.requestId);
+      if (listener == null) {
+        logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
+          resp.requestId, remoteAddress, resp.body().size());
+        // Unclaimed responses must still release their body, or the buffer leaks.
+        resp.body().release();
+      } else {
+        try {
+          listener.onSuccess(resp.body().nioByteBuffer());
+        } finally {
+          resp.body().release();
+        }
+      }
+    } else if (message instanceof RpcFailure) {
+      RpcFailure resp = (RpcFailure) message;
+      RpcResponseCallback listener = outstandingRpcs.remove(resp.requestId);
+      if (listener == null) {
+        logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
+          resp.requestId, remoteAddress, resp.errorString);
+      } else {
+        listener.onFailure(new RuntimeException(resp.errorString));
+      }
+    } else if (message instanceof StreamResponse) {
+      StreamResponse resp = (StreamResponse) message;
+      StreamCallback callback = streamCallbacks.poll();
+      if (callback != null) {
+        if (resp.byteCount > 0) {
+          StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
+            callback);
+          try {
+            TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+              channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+            frameDecoder.setInterceptor(interceptor);
+            streamActive = true;
+          } catch (Exception e) {
+            logger.error("Error installing stream handler.", e);
+            deactivateStream();
+            // The callback was already dequeued; report the failure rather than dropping it.
+            try {
+              callback.onFailure(resp.streamId,
+                new RuntimeException("Error installing stream handler.", e));
+            } catch (IOException ioe) {
+              logger.warn("Error in stream failure handler.", ioe);
+            }
+          }
+        } else {
+          try {
+            callback.onComplete(resp.streamId);
+          } catch (Exception e) {
+            logger.warn("Error in stream handler onComplete().", e);
+          }
+        }
+      } else {
+        logger.error("Could not find callback for StreamResponse.");
+      }
+    } else if (message instanceof StreamFailure) {
+      StreamFailure resp = (StreamFailure) message;
+      StreamCallback callback = streamCallbacks.poll();
+      if (callback != null) {
+        try {
+          callback.onFailure(resp.streamId, new RuntimeException(resp.error));
+        } catch (IOException ioe) {
+          logger.warn("Error in stream failure handler.", ioe);
+        }
+      } else {
+        logger.warn("Stream failure with unknown callback: {}", resp.error);
+      }
+    } else {
+      throw new IllegalStateException("Unknown response type: " + message.type());
+    }
+  }
+
+  /**
+   * Returns the total number of outstanding requests: fetch requests, RPCs, queued stream
+   * callbacks, plus one while {@code streamActive} is set.
+   */
+  public int numOutstandingRequests() {
+    return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
+      (streamActive ? 1 : 0);
+  }
+
+  /** Returns the time in nanoseconds of when the last request was sent out. */
+  public long getTimeOfLastRequestNs() {
+    return timeOfLastRequestNs.get();
+  }
+
+  /** Updates the time of the last request to the current system time. */
+  public void updateTimeOfLastRequest() {
+    timeOfLastRequestNs.set(System.nanoTime());
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
new file mode 100644
index 0000000..2924218
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Base class for messages that may carry an optional body held in a separate buffer.
+ */
+public abstract class AbstractMessage implements Message {
+  /** Optional body of the message; {@code null} when the message has none. */
+  private final ManagedBuffer body;
+  /** Whether the body should be written in the same frame as the message itself. */
+  private final boolean isBodyInFrame;
+
+  /** Creates a message without a body. */
+  protected AbstractMessage() {
+    this(null, false);
+  }
+
+  /**
+   * @param body the message body, or {@code null} if there is none
+   * @param isBodyInFrame whether to include the body in the same frame as the message
+   */
+  protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) {
+    this.isBodyInFrame = isBodyInFrame;
+    this.body = body;
+  }
+
+  @Override
+  public ManagedBuffer body() {
+    return body;
+  }
+
+  @Override
+  public boolean isBodyInFrame() {
+    return isBodyInFrame;
+  }
+
+  /**
+   * Compares the body-related state of this message with another one. Intended to be called
+   * from a subclass's {@code equals(Object)} after it has compared its own fields.
+   */
+  protected boolean equals(AbstractMessage other) {
+    if (isBodyInFrame != other.isBodyInFrame) {
+      return false;
+    }
+    return Objects.equal(body, other.body);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
new file mode 100644
index 0000000..c362c92
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Base class for response messages, each of which can build its matching failure response.
+ */
+public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage {
+
+  protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) {
+    super(body, isBodyInFrame);
+  }
+
+  /**
+   * Builds the failure response that corresponds to this response, carrying the given error.
+   *
+   * @param error description of the error to report
+   * @return a response message signalling the failure
+   */
+  public abstract ResponseMessage createFailureResponse(String error);
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
new file mode 100644
index 0000000..7b28a9a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
+ */
+public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage {
+  /** The chunk whose fetch failed. */
+  public final StreamChunkId streamChunkId;
+  /** Description of the error that prevented the chunk from being fetched. */
+  public final String errorString;
+
+  public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
+    this.streamChunkId = streamChunkId;
+    this.errorString = errorString;
+  }
+
+  @Override
+  public Type type() { return Type.ChunkFetchFailure; }
+
+  @Override
+  public int encodedLength() {
+    return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
+  }
+
+  /** Writes the chunk id followed by the length-prefixed UTF-8 error string. */
+  @Override
+  public void encode(ByteBuf buf) {
+    streamChunkId.encode(buf);
+    Encoders.Strings.encode(buf, errorString);
+  }
+
+  /** Reads a failure in the format produced by {@link #encode(ByteBuf)}. */
+  public static ChunkFetchFailure decode(ByteBuf buf) {
+    StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+    String errorString = Encoders.Strings.decode(buf);
+    return new ChunkFetchFailure(streamChunkId, errorString);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(streamChunkId, errorString);
+  }
+
+  /**
+   * Two failures are equal when they refer to the same chunk and carry the same error. The
+   * comparison is null-safe, matching {@link #hashCode()}, which also tolerates null fields.
+   */
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof ChunkFetchFailure) {
+      ChunkFetchFailure o = (ChunkFetchFailure) other;
+      return Objects.equal(streamChunkId, o.streamChunkId) &&
+        Objects.equal(errorString, o.errorString);
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamChunkId", streamChunkId)
+      .add("errorString", errorString)
+      .toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
new file mode 100644
index 0000000..26d063f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to fetch a single chunk of a stream. Each request corresponds to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage}: either a success or a failure.
+ */
+public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage {
+  /** Identifies the requested chunk. */
+  public final StreamChunkId streamChunkId;
+
+  public ChunkFetchRequest(StreamChunkId streamChunkId) {
+    this.streamChunkId = streamChunkId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.ChunkFetchRequest;
+  }
+
+  /** The encoded form consists solely of the chunk id. */
+  @Override
+  public int encodedLength() {
+    return streamChunkId.encodedLength();
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    streamChunkId.encode(buf);
+  }
+
+  /** Reads a request in the format produced by {@link #encode(ByteBuf)}. */
+  public static ChunkFetchRequest decode(ByteBuf buf) {
+    StreamChunkId id = StreamChunkId.decode(buf);
+    return new ChunkFetchRequest(id);
+  }
+
+  @Override
+  public int hashCode() {
+    return streamChunkId.hashCode();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof ChunkFetchRequest)) {
+      return false;
+    }
+    return streamChunkId.equals(((ChunkFetchRequest) other).streamChunkId);
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamChunkId", streamChunkId)
+      .toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
new file mode 100644
index 0000000..94c2ac9
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched.
+ *
+ * Note that the server-side encoding of this message does NOT include the buffer itself, as this
+ * may be written by Netty in a more efficient manner (i.e., zero-copy write).
+ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
+ */
+public final class ChunkFetchSuccess extends AbstractResponseMessage {
+  /** Identifies the chunk carried by this response. */
+  public final StreamChunkId streamChunkId;
+
+  /** The chunk data becomes the message body, sent in the same frame as the message. */
+  public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+    super(buffer, true);
+    this.streamChunkId = streamChunkId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.ChunkFetchSuccess;
+  }
+
+  @Override
+  public int encodedLength() {
+    return streamChunkId.encodedLength();
+  }
+
+  /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+  @Override
+  public void encode(ByteBuf buf) {
+    streamChunkId.encode(buf);
+  }
+
+  /** Builds the {@link ChunkFetchFailure} for the same chunk, carrying the given error. */
+  @Override
+  public ResponseMessage createFailureResponse(String error) {
+    return new ChunkFetchFailure(streamChunkId, error);
+  }
+
+  /**
+   * Decodes the chunk id and wraps the remaining content of the given ByteBuf as the body. The
+   * ByteBuf is retained, because the returned message keeps using its memory.
+   */
+  public static ChunkFetchSuccess decode(ByteBuf buf) {
+    StreamChunkId chunkId = StreamChunkId.decode(buf);
+    buf.retain();
+    ManagedBuffer data = new NettyManagedBuffer(buf.duplicate());
+    return new ChunkFetchSuccess(chunkId, data);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(streamChunkId, body());
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof ChunkFetchSuccess)) {
+      return false;
+    }
+    ChunkFetchSuccess that = (ChunkFetchSuccess) other;
+    return streamChunkId.equals(that.streamChunkId) && super.equals(that);
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamChunkId", streamChunkId)
+      .add("buffer", body())
+      .toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java
new file mode 100644
index 0000000..b4e2994
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * An object that can serialize itself into a ByteBuf. Several Encodable objects are written into
+ * one pre-allocated ByteBuf, which is why each must report its encoded length up front.
+ *
+ * Implementations should also provide a static "decode(ByteBuf)" method, which is invoked by
+ * {@link MessageDecoder}. If a decoded object keeps using the ByteBuf as its data (instead of
+ * copying bytes out of it), it must retain() the ByteBuf.
+ *
+ * Whenever a new Encodable Message is added, it must also be added to {@link Message.Type}.
+ */
+public interface Encodable {
+  /** Returns the number of bytes that {@link #encode(ByteBuf)} writes for this object. */
+  int encodedLength();
+
+  /**
+   * Serializes this object into the given ByteBuf, writing exactly {@link #encodedLength()}
+   * bytes.
+   */
+  void encode(ByteBuf buf);
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
new file mode 100644
index 0000000..9162d0b
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+  /** Strings are encoded with their length followed by UTF-8 bytes. */
+  public static class Strings {
+    public static int encodedLength(String s) {
+      return 4 + s.getBytes(Charsets.UTF_8).length;
+    }
+
+    public static void encode(ByteBuf buf, String s) {
+      byte[] bytes = s.getBytes(Charsets.UTF_8);
+      buf.writeInt(bytes.length);
+      buf.writeBytes(bytes);
+    }
+
+    /** Decodes a string; throws IndexOutOfBoundsException on an invalid length prefix. */
+    public static String decode(ByteBuf buf) {
+      return new String(readLengthPrefixedBytes(buf), Charsets.UTF_8);
+    }
+  }
+
+  /** Byte arrays are encoded with their length followed by bytes. */
+  public static class ByteArrays {
+    public static int encodedLength(byte[] arr) {
+      return 4 + arr.length;
+    }
+
+    public static void encode(ByteBuf buf, byte[] arr) {
+      buf.writeInt(arr.length);
+      buf.writeBytes(arr);
+    }
+
+    /** Decodes a byte array; throws IndexOutOfBoundsException on an invalid length prefix. */
+    public static byte[] decode(ByteBuf buf) {
+      return readLengthPrefixedBytes(buf);
+    }
+  }
+
+  /** String arrays are encoded with the number of strings followed by per-String encoding. */
+  public static class StringArrays {
+    public static int encodedLength(String[] strings) {
+      int totalLength = 4;
+      for (String s : strings) {
+        totalLength += Strings.encodedLength(s);
+      }
+      return totalLength;
+    }
+
+    public static void encode(ByteBuf buf, String[] strings) {
+      buf.writeInt(strings.length);
+      for (String s : strings) {
+        Strings.encode(buf, s);
+      }
+    }
+
+    /** Decodes a string array; throws IndexOutOfBoundsException on an invalid count. */
+    public static String[] decode(ByteBuf buf) {
+      int numStrings = buf.readInt();
+      // Each encoded string needs at least its 4-byte length prefix, so reject counts the
+      // remaining input cannot hold before allocating the array.
+      if (numStrings < 0 || numStrings > buf.readableBytes() / 4) {
+        throw new IndexOutOfBoundsException("Invalid string count: " + numStrings);
+      }
+      String[] strings = new String[numStrings];
+      for (int i = 0; i < strings.length; i++) {
+        strings[i] = Strings.decode(buf);
+      }
+      return strings;
+    }
+  }
+
+  /**
+   * Reads a 4-byte length prefix followed by that many bytes. The prefix is checked against the
+   * readable bytes before allocating, so a corrupt or malicious length cannot trigger a huge
+   * allocation or a NegativeArraySizeException.
+   *
+   * @throws IndexOutOfBoundsException if the length is negative or exceeds the readable bytes
+   */
+  private static byte[] readLengthPrefixedBytes(ByteBuf buf) {
+    int length = buf.readInt();
+    if (length < 0 || length > buf.readableBytes()) {
+      throw new IndexOutOfBoundsException("Invalid length prefix " + length + " with only " +
+        buf.readableBytes() + " readable bytes");
+    }
+    byte[] bytes = new byte[length];
+    buf.readBytes(bytes);
+    return bytes;
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
new file mode 100644
index 0000000..66f5b8b
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/** A message that can be transmitted over the wire. */
+public interface Message extends Encodable {
+  /** Returns the type of this message, which identifies it on the wire. */
+  Type type();
+
+  /** Returns the optional body of this message; implementations may return {@code null}. */
+  ManagedBuffer body();
+
+  /** Whether the body should be written in the same frame as the message itself. */
+  boolean isBodyInFrame();
+
+  /**
+   * Identifies the kind of a message. Every serialized Message is preceded by its encoded Type
+   * (a single byte), which is what allows it to be deserialized.
+   */
+  enum Type implements Encodable {
+    ChunkFetchRequest(0),
+    ChunkFetchSuccess(1),
+    ChunkFetchFailure(2),
+    RpcRequest(3),
+    RpcResponse(4),
+    RpcFailure(5),
+    StreamRequest(6),
+    StreamResponse(7),
+    StreamFailure(8),
+    OneWayMessage(9),
+    User(-1);
+
+    /** Wire identifier written by {@link #encode(ByteBuf)}. */
+    private final byte id;
+
+    Type(int id) {
+      assert id < 128 : "Cannot have more than 128 message types";
+      this.id = (byte) id;
+    }
+
+    public byte id() {
+      return id;
+    }
+
+    @Override
+    public int encodedLength() {
+      return 1;
+    }
+
+    @Override
+    public void encode(ByteBuf buf) {
+      buf.writeByte(id);
+    }
+
+    /**
+     * Reads a Type written by {@link #encode(ByteBuf)}.
+     *
+     * @throws IllegalArgumentException for the {@code User} type, which cannot be decoded, or
+     *     for an unknown id
+     */
+    public static Type decode(ByteBuf buf) {
+      byte id = buf.readByte();
+      switch (id) {
+        case 0: return ChunkFetchRequest;
+        case 1: return ChunkFetchSuccess;
+        case 2: return ChunkFetchFailure;
+        case 3: return RpcRequest;
+        case 4: return RpcResponse;
+        case 5: return RpcFailure;
+        case 6: return StreamRequest;
+        case 7: return StreamResponse;
+        case 8: return StreamFailure;
+        case 9: return OneWayMessage;
+        case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
+        default: throw new IllegalArgumentException("Unknown message type: " + id);
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
new file mode 100644
index 0000000..074780f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Decodes inbound buffers into {@link Message} instances. Each buffer is expected to begin with
+ * an encoded {@link Message.Type}, which selects the concrete message class that parses the
+ * remaining bytes. Both request and response message types are handled.
+ * This decoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
+
+  private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
+
+  @Override
+  public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
+    Message.Type msgType = Message.Type.decode(in);
+    Message decoded = decode(msgType, in);
+    assert decoded.type() == msgType;
+    // Parameterized form: the message string (and decoded.toString()) is only built when trace
+    // logging is actually enabled, instead of on every decoded message.
+    logger.trace("Received message {}: {}", msgType, decoded);
+    out.add(decoded);
+  }
+
+  /**
+   * Dispatches to the static {@code decode} method of the message class matching
+   * {@code msgType}.
+   *
+   * @throws IllegalArgumentException for a type that has no decoder here.
+   */
+  private Message decode(Message.Type msgType, ByteBuf in) {
+    switch (msgType) {
+      case ChunkFetchRequest:
+        return ChunkFetchRequest.decode(in);
+
+      case ChunkFetchSuccess:
+        return ChunkFetchSuccess.decode(in);
+
+      case ChunkFetchFailure:
+        return ChunkFetchFailure.decode(in);
+
+      case RpcRequest:
+        return RpcRequest.decode(in);
+
+      case RpcResponse:
+        return RpcResponse.decode(in);
+
+      case RpcFailure:
+        return RpcFailure.decode(in);
+
+      case OneWayMessage:
+        return OneWayMessage.decode(in);
+
+      case StreamRequest:
+        return StreamRequest.decode(in);
+
+      case StreamResponse:
+        return StreamResponse.decode(in);
+
+      case StreamFailure:
+        return StreamFailure.decode(in);
+
+      default:
+        throw new IllegalArgumentException("Unexpected message type: " + msgType);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
new file mode 100644
index 0000000..664df57
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Encodes outbound {@link Message}s into frames: an 8-byte frame length, the message type, the
+ * message's own encoded fields and, optionally, the message body.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageEncoder extends MessageToMessageEncoder<Message> {
+
+  private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
+
+  /**
+   * Encodes a Message by invoking its encode() method. A header ByteBuf is built containing the
+   * total frame length, the message type, and the message itself. If the message has no body,
+   * only that header is added to 'out'. If it has a body, a single {@link MessageWithHeader}
+   * wrapping both the header and the body is added instead, so that the body can be transferred
+   * without copying it into the header (zero-copy transfer).
+   *
+   * If extracting the body fails, the body is released. For an {@link AbstractResponseMessage}
+   * the error is logged and the corresponding failure response is encoded in its place; for any
+   * other message the exception is rethrown.
+   */
+  @Override
+  public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
+    Object body = null;
+    long bodyLength = 0;
+    boolean isBodyInFrame = false;
+
+    // If the message has a body, take it out to enable zero-copy transfer for the payload.
+    if (in.body() != null) {
+      try {
+        bodyLength = in.body().size();
+        body = in.body().convertToNetty();
+        isBodyInFrame = in.isBodyInFrame();
+      } catch (Exception e) {
+        in.body().release();
+        if (in instanceof AbstractResponseMessage) {
+          AbstractResponseMessage resp = (AbstractResponseMessage) in;
+          // Re-encode this message as a failure response.
+          String error = e.getMessage() != null ? e.getMessage() : "null";
+          logger.error(String.format("Error processing %s for client %s",
+            in, ctx.channel().remoteAddress()), e);
+          encode(ctx, resp.createFailureResponse(error), out);
+        } else {
+          throw e;
+        }
+        return;
+      }
+    }
+
+    Message.Type msgType = in.type();
+    // All messages have the frame length, message type, and message itself. The frame length
+    // may optionally include the length of the body data, depending on what message is being
+    // sent.
+    int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
+    long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
+    ByteBuf header = ctx.alloc().heapBuffer(headerLength);
+    header.writeLong(frameLength);
+    msgType.encode(header);
+    in.encode(header);
+    assert header.writableBytes() == 0;
+
+    if (body != null) {
+      // We transfer ownership of the reference on in.body() to MessageWithHeader.
+      // This reference will be freed when MessageWithHeader.deallocate() is called.
+      out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
+    } else {
+      out.add(header);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
new file mode 100644
index 0000000..66227f9
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.io.IOException;
+import java.nio.channels.WritableByteChannel;
+import javax.annotation.Nullable;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+import io.netty.util.ReferenceCountUtil;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion.
+ */
+class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
+
+  /**
+   * Maximum number of bytes handed to a single {@link WritableByteChannel#write} call when a
+   * ByteBuf is copied to the target channel. Passing the whole readable region at once can
+   * force a copy of the entire region (e.g. the JDK copies heap buffers into a temporary direct
+   * buffer, and a multi-component composite buffer is merged into one buffer) even when the
+   * channel accepts only a fraction of it. Capping the slice bounds that copy; the remaining
+   * bytes are sent by subsequent transferTo() calls.
+   */
+  private static final int NIO_BUFFER_LIMIT = 256 * 1024;
+
+  @Nullable private final ManagedBuffer managedBuffer;
+  private final ByteBuf header;
+  private final int headerLength;
+  private final Object body;
+  private final long bodyLength;
+  private long totalBytesTransferred;
+
+  /**
+   * Construct a new MessageWithHeader.
+   *
+   * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
+   *                      be passed in so that the buffer can be freed when this message is
+   *                      deallocated. Ownership of the caller's reference to this buffer is
+   *                      transferred to this class, so if the caller wants to continue to use the
+   *                      ManagedBuffer in other messages then they will need to call retain() on
+   *                      it before passing it to this constructor. This may be null if and only if
+   *                      `body` is a {@link FileRegion}.
+   * @param header the message header.
+   * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}.
+   * @param bodyLength the length of the message body, in bytes.
+   */
+  MessageWithHeader(
+      @Nullable ManagedBuffer managedBuffer,
+      ByteBuf header,
+      Object body,
+      long bodyLength) {
+    Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
+      "Body must be a ByteBuf or a FileRegion.");
+    this.managedBuffer = managedBuffer;
+    this.header = header;
+    this.headerLength = header.readableBytes();
+    this.body = body;
+    this.bodyLength = bodyLength;
+  }
+
+  @Override
+  public long count() {
+    return headerLength + bodyLength;
+  }
+
+  @Override
+  public long position() {
+    return 0;
+  }
+
+  @Override
+  public long transfered() {
+    return totalBytesTransferred;
+  }
+
+  /**
+   * This code is more complicated than you would think because we might require multiple
+   * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting.
+   * ByteBuf pieces are additionally written in slices of at most NIO_BUFFER_LIMIT bytes, so a
+   * single call may transfer only part of the body.
+   *
+   * The contract is that the caller will ensure position is properly set to the total number
+   * of bytes transferred so far (i.e. value returned by transfered()).
+   */
+  @Override
+  public long transferTo(final WritableByteChannel target, final long position) throws IOException {
+    Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");
+    // Bytes written for header in this call.
+    long writtenHeader = 0;
+    if (header.readableBytes() > 0) {
+      writtenHeader = copyByteBuf(header, target);
+      totalBytesTransferred += writtenHeader;
+      if (header.readableBytes() > 0) {
+        // The header must be fully written before any body bytes are sent.
+        return writtenHeader;
+      }
+    }
+
+    // Bytes written for body in this call.
+    long writtenBody = 0;
+    if (body instanceof FileRegion) {
+      writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength);
+    } else if (body instanceof ByteBuf) {
+      writtenBody = copyByteBuf((ByteBuf) body, target);
+    }
+    totalBytesTransferred += writtenBody;
+
+    return writtenHeader + writtenBody;
+  }
+
+  @Override
+  protected void deallocate() {
+    header.release();
+    ReferenceCountUtil.release(body);
+    if (managedBuffer != null) {
+      managedBuffer.release();
+    }
+  }
+
+  /**
+   * Writes at most NIO_BUFFER_LIMIT readable bytes of {@code buf} to {@code target} and advances
+   * the reader index of {@code buf} by the number of bytes actually written.
+   *
+   * @return the number of bytes written, possibly 0.
+   */
+  private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
+    int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT);
+    int written = target.write(buf.nioBuffer(buf.readerIndex(), length));
+    buf.skipBytes(written);
+    return written;
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
new file mode 100644
index 0000000..efe0470
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * A RPC that does not expect a reply, which is handled by a remote
+ * {@link org.apache.spark.network.server.RpcHandler}.
+ */
+public final class OneWayMessage extends AbstractMessage implements RequestMessage {
+
+ public OneWayMessage(ManagedBuffer body) {
+ super(body, true);
+ }
+
+ @Override
+ public Type type() { return Type.OneWayMessage; }
+
+ @Override
+ public int encodedLength() {
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
+ }
+
+ public static OneWayMessage decode(ByteBuf buf) {
+ // See comment in encodedLength().
+ buf.readInt();
+ return new OneWayMessage(new NettyManagedBuffer(buf.retain()));
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(body());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof OneWayMessage) {
+ OneWayMessage o = (OneWayMessage) other;
+ return super.equals(o);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("body", body())
+ .toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
new file mode 100644
index 0000000..31b15bb
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import org.apache.spark.network.protocol.Message;
+
+/**
+ * Messages from the client to the server. This is a marker interface: it declares no members
+ * of its own and only tags {@link Message} implementations by direction.
+ */
+public interface RequestMessage extends Message {
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
new file mode 100644
index 0000000..6edffd1
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import org.apache.spark.network.protocol.Message;
+
+/**
+ * Messages from the server to the client. This is a marker interface: it declares no members
+ * of its own and only tags {@link Message} implementations by direction.
+ */
+public interface ResponseMessage extends Message {
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
new file mode 100644
index 0000000..a76624e
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/** Response to {@link RpcRequest} for a failed RPC. */
+public final class RpcFailure extends AbstractMessage implements ResponseMessage {
+ public final long requestId;
+ public final String errorString;
+
+ public RpcFailure(long requestId, String errorString) {
+ this.requestId = requestId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.RpcFailure; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + Encoders.Strings.encodedLength(errorString);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ Encoders.Strings.encode(buf, errorString);
+ }
+
+ public static RpcFailure decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ String errorString = Encoders.Strings.decode(buf);
+ return new RpcFailure(requestId, errorString);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(requestId, errorString);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcFailure) {
+ RpcFailure o = (RpcFailure) other;
+ return requestId == o.requestId && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
new file mode 100644
index 0000000..9621379
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * A generic RPC handled by a remote {@link org.apache.spark.network.server.RpcHandler}.
+ * Each request corresponds to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage}, either success or failure.
+ */
+public final class RpcRequest extends AbstractMessage implements RequestMessage {
+  /** Used to link an RPC request with its response. */
+  public final long requestId;
+
+  public RpcRequest(long requestId, ManagedBuffer message) {
+    super(message, true);
+    this.requestId = requestId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.RpcRequest;
+  }
+
+  /**
+   * 8 bytes of request id followed by a 4-byte body size. The body size is redundant with the
+   * frame length and is discarded on decode; it is kept for backwards compatibility with
+   * versions of RpcRequest that used Encoders.ByteArrays.
+   */
+  @Override
+  public int encodedLength() {
+    return 8 + 4;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    // Redundant body size; see encodedLength().
+    int bodySize = (int) body().size();
+    buf.writeInt(bodySize);
+  }
+
+  /** Reads the request id, skips the redundant body size, and retains the rest as the body. */
+  public static RpcRequest decode(ByteBuf buf) {
+    long id = buf.readLong();
+    buf.readInt();
+    ManagedBuffer payload = new NettyManagedBuffer(buf.retain());
+    return new RpcRequest(id, payload);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, body());
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof RpcRequest)) {
+      return false;
+    }
+    RpcRequest that = (RpcRequest) other;
+    return requestId == that.requestId && super.equals(that);
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("requestId", requestId)
+      .add("body", body())
+      .toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
new file mode 100644
index 0000000..bae866e
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/** Response to {@link RpcRequest} for a successful RPC. */
+public final class RpcResponse extends AbstractResponseMessage {
+  /** Id of the request this response answers. */
+  public final long requestId;
+
+  public RpcResponse(long requestId, ManagedBuffer message) {
+    super(message, true);
+    this.requestId = requestId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.RpcResponse;
+  }
+
+  /**
+   * 8 bytes of request id followed by a 4-byte body size. The body size is redundant with the
+   * frame length and is discarded on decode; it keeps the wire format compatible with older
+   * versions that encoded the body with Encoders.ByteArrays.
+   */
+  @Override
+  public int encodedLength() {
+    return 8 + 4;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    // Redundant body size; see encodedLength().
+    int bodySize = (int) body().size();
+    buf.writeInt(bodySize);
+  }
+
+  /** Builds the failure response for the same request id. */
+  @Override
+  public ResponseMessage createFailureResponse(String error) {
+    return new RpcFailure(requestId, error);
+  }
+
+  /** Reads the request id, skips the redundant body size, and retains the rest as the body. */
+  public static RpcResponse decode(ByteBuf buf) {
+    long id = buf.readLong();
+    buf.readInt();
+    ManagedBuffer payload = new NettyManagedBuffer(buf.retain());
+    return new RpcResponse(id, payload);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, body());
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof RpcResponse)) {
+      return false;
+    }
+    RpcResponse that = (RpcResponse) other;
+    return requestId == that.requestId && super.equals(that);
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("requestId", requestId)
+      .add("body", body())
+      .toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
new file mode 100644
index 0000000..d46a263
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Encapsulates a request for a particular chunk of a stream.
+ */
+public final class StreamChunkId implements Encodable {
+  /** Encoded size: 8 bytes for the stream id plus 4 bytes for the chunk index. */
+  private static final int ENCODED_LENGTH = 8 + 4;
+
+  public final long streamId;
+  public final int chunkIndex;
+
+  public StreamChunkId(long streamId, int chunkIndex) {
+    this.streamId = streamId;
+    this.chunkIndex = chunkIndex;
+  }
+
+  @Override
+  public int encodedLength() {
+    return ENCODED_LENGTH;
+  }
+
+  @Override
+  public void encode(ByteBuf buffer) {
+    buffer.writeLong(streamId);
+    buffer.writeInt(chunkIndex);
+  }
+
+  /** Reads a stream id followed by a chunk index, in the order written by encode(). */
+  public static StreamChunkId decode(ByteBuf buffer) {
+    assert buffer.readableBytes() >= ENCODED_LENGTH;
+    long streamId = buffer.readLong();
+    int chunkIndex = buffer.readInt();
+    return new StreamChunkId(streamId, chunkIndex);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(streamId, chunkIndex);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof StreamChunkId) {
+      StreamChunkId o = (StreamChunkId) other;
+      return streamId == o.streamId && chunkIndex == o.chunkIndex;
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamId", streamId)
+      .add("chunkIndex", chunkIndex)
+      .toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
new file mode 100644
index 0000000..26747ee
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Message indicating an error when transferring a stream.
+ */
+public final class StreamFailure extends AbstractMessage implements ResponseMessage {
+ public final String streamId;
+ public final String error;
+
+ public StreamFailure(String streamId, String error) {
+ this.streamId = streamId;
+ this.error = error;
+ }
+
+ @Override
+ public Type type() { return Type.StreamFailure; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ Encoders.Strings.encode(buf, error);
+ }
+
+ public static StreamFailure decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ String error = Encoders.Strings.decode(buf);
+ return new StreamFailure(streamId, error);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, error);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamFailure) {
+ StreamFailure o = (StreamFailure) other;
+ return streamId.equals(o.streamId) && error.equals(o.error);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("error", error)
+ .toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
new file mode 100644
index 0000000..35af5a8
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Request to stream data from the remote end.
+ * <p>
+ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before
+ * the data can be streamed.
+ */
+public final class StreamRequest extends AbstractMessage implements RequestMessage {
+  /** Identifier of the requested stream; its meaning is agreed upon by the two endpoints. */
+  public final String streamId;
+
+  public StreamRequest(String streamId) {
+    this.streamId = streamId;
+  }
+
+  @Override
+  public Type type() { return Type.StreamRequest; }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.Strings.encodedLength(streamId);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, streamId);
+  }
+
+  /** Reads a {@link StreamRequest} written by {@link #encode}. */
+  public static StreamRequest decode(ByteBuf buf) {
+    String streamId = Encoders.Strings.decode(buf);
+    return new StreamRequest(streamId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(streamId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof StreamRequest) {
+      StreamRequest o = (StreamRequest) other;
+      // Null-safe comparison, consistent with hashCode(), which also tolerates a null id.
+      return Objects.equal(streamId, o.streamId);
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamId", streamId)
+      .toString();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
new file mode 100644
index 0000000..51b8999
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link StreamRequest} when the stream has been successfully opened.
+ * <p>
+ * Note the message itself does not contain the stream data. That is written separately by the
+ * sender. The receiver is expected to set a temporary channel handler that will consume the
+ * number of bytes this message says the stream has.
+ */
+public final class StreamResponse extends AbstractResponseMessage {
+  /** Identifier of the stream that was opened. */
+  public final String streamId;
+  /** Number of bytes in the stream, as announced to the receiver. */
+  public final long byteCount;
+
+  /**
+   * @param streamId identifier of the opened stream
+   * @param byteCount number of bytes in the stream
+   * @param buffer the stream data on the sending side; {@link #decode} passes {@code null}
+   */
+  public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
+    super(buffer, false);
+    this.streamId = streamId;
+    this.byteCount = byteCount;
+  }
+
+  @Override
+  public Type type() { return Type.StreamResponse; }
+
+  /** 8 bytes for {@code byteCount} plus the encoded size of {@code streamId}. */
+  @Override
+  public int encodedLength() {
+    return 8 + Encoders.Strings.encodedLength(streamId);
+  }
+
+  /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, streamId);
+    buf.writeLong(byteCount);
+  }
+
+  /** Returns a {@link StreamFailure} for the same stream carrying {@code error}. */
+  @Override
+  public ResponseMessage createFailureResponse(String error) {
+    return new StreamFailure(streamId, error);
+  }
+
+  /**
+   * Reads the header written by {@link #encode}. The returned message has no body; the stream
+   * data is consumed separately by the receiver.
+   */
+  public static StreamResponse decode(ByteBuf buf) {
+    String streamId = Encoders.Strings.decode(buf);
+    long byteCount = buf.readLong();
+    return new StreamResponse(streamId, byteCount, null);
+  }
+
+  @Override
+  public int hashCode() {
+    // Hash only the fields compared by equals(). Including body() here broke the
+    // equals/hashCode contract: a decoded response (null body) equals the sender's response
+    // (non-null body) but hashed differently.
+    return Objects.hashCode(byteCount, streamId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof StreamResponse) {
+      StreamResponse o = (StreamResponse) other;
+      return byteCount == o.byteCount && streamId.equals(o.streamId);
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamId", streamId)
+      .add("byteCount", byteCount)
+      .add("body", body())
+      .toString();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
new file mode 100644
index 0000000..6838103
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be set up with a {@link SaslRpcHandler} with matching keys for the given appId.
+ */
+public class SaslClientBootstrap implements TransportClientBootstrap {
+  private static final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
+
+  private final boolean encrypt;
+  private final TransportConf conf;
+  private final String appId;
+  private final SecretKeyHolder secretKeyHolder;
+
+  /** Creates a bootstrap that authenticates without requesting SASL encryption. */
+  public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+    this(conf, appId, secretKeyHolder, false);
+  }
+
+  /**
+   * @param conf transport configuration; supplies the SASL round-trip timeout and the maximum
+   *             SASL encrypted block size
+   * @param appId application id sent with every SASL message and set as the client id
+   * @param secretKeyHolder passed to the {@link SparkSaslClient} used for authentication
+   * @param encrypt whether to require SASL encryption and install it on the channel
+   */
+  public SaslClientBootstrap(
+      TransportConf conf,
+      String appId,
+      SecretKeyHolder secretKeyHolder,
+      boolean encrypt) {
+    this.conf = conf;
+    this.appId = appId;
+    this.secretKeyHolder = secretKeyHolder;
+    this.encrypt = encrypt;
+  }
+
+  /**
+   * Performs SASL authentication by sending a token, and then proceeding with the SASL
+   * challenge-response tokens until we either successfully authenticate or throw an exception
+   * due to mismatch.
+   * <p>
+   * If encryption was requested, the negotiated QOP must equal
+   * {@code SparkSaslServer.QOP_AUTH_CONF}; otherwise a {@link RuntimeException} wrapping a
+   * {@link SaslException} is thrown. {@link IOException}s are rethrown wrapped in a
+   * {@link RuntimeException}.
+   */
+  @Override
+  public void doBootstrap(TransportClient client, Channel channel) {
+    SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
+    try {
+      byte[] payload = saslClient.firstToken();
+
+      // Exchange tokens with the server over synchronous RPCs until the handshake completes.
+      while (!saslClient.isComplete()) {
+        SaslMessage msg = new SaslMessage(appId, payload);
+        // Header and token body are copied into a single buffer so they go out as one RPC.
+        ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
+        msg.encode(buf);
+        buf.writeBytes(msg.body().nioByteBuffer());
+
+        ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
+        payload = saslClient.response(JavaUtils.bufferToArray(response));
+      }
+
+      client.setClientId(appId);
+
+      if (encrypt) {
+        if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
+          throw new RuntimeException(
+            new SaslException("Encryption requests by negotiated non-encrypted connection."));
+        }
+        SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
+        // The SASL client has been handed to SaslEncryption for use on this channel (presumably
+        // it now owns disposal), so clear the reference to skip dispose() in the finally block.
+        saslClient = null;
+        logger.debug("Channel {} configured for SASL encryption.", client);
+      }
+    } catch (IOException ioe) {
+      throw new RuntimeException(ioe);
+    } finally {
+      // Runs on success without encryption and on every failure path.
+      if (saslClient != null) {
+        try {
+          // Once authentication is complete, the server will trust all remaining communication.
+          saslClient.dispose();
+        } catch (RuntimeException e) {
+          logger.error("Error while disposing SASL client", e);
+        }
+      }
+    }
+  }
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org