Posted to commits@spark.apache.org by rx...@apache.org on 2016/02/29 02:25:27 UTC
[11/14] spark git commit: [SPARK-13529][BUILD] Move network/* modules into common/network-*
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java
new file mode 100644
index 0000000..a2f0183
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.util;
+
+public enum ByteUnit {
+ BYTE (1),
+ KiB (1024L),
+ MiB ((long) Math.pow(1024L, 2L)),
+ GiB ((long) Math.pow(1024L, 3L)),
+ TiB ((long) Math.pow(1024L, 4L)),
+ PiB ((long) Math.pow(1024L, 5L));
+
+ private ByteUnit(long multiplier) {
+ this.multiplier = multiplier;
+ }
+
+ // Interpret the provided number (d) with suffix (u) as this unit type.
+ // E.g. KiB.convertFrom(1, MiB) interprets 1 MiB in its KiB representation, i.e. 1024
+ public long convertFrom(long d, ByteUnit u) {
+ return u.convertTo(d, this);
+ }
+
+ // Convert the provided number (d) interpreted as this unit type to unit type (u).
+ public long convertTo(long d, ByteUnit u) {
+ if (multiplier > u.multiplier) {
+ long ratio = multiplier / u.multiplier;
+ if (Long.MAX_VALUE / ratio < d) {
+ throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in "
+ + name() + ". Try a larger unit (e.g. MiB instead of KiB)");
+ }
+ return d * ratio;
+ } else {
+ // Perform operations in this order to avoid potential overflow
+ // when computing d * multiplier
+ return d / (u.multiplier / multiplier);
+ }
+ }
+
+ public double toBytes(long d) {
+ if (d < 0) {
+ throw new IllegalArgumentException("Negative size value. Size must be non-negative: " + d);
+ }
+ return d * multiplier;
+ }
+
+ public long toKiB(long d) { return convertTo(d, KiB); }
+ public long toMiB(long d) { return convertTo(d, MiB); }
+ public long toGiB(long d) { return convertTo(d, GiB); }
+ public long toTiB(long d) { return convertTo(d, TiB); }
+ public long toPiB(long d) { return convertTo(d, PiB); }
+
+ private final long multiplier;
+}
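
For context, the conversion methods above compose as in this short sketch (illustrative, not part of the diff):

    // Values follow from the multipliers defined in the enum above.
    long kib = ByteUnit.KiB.convertFrom(1, ByteUnit.MiB); // 1 MiB expressed in KiB = 1024
    long mib = ByteUnit.GiB.toMiB(2);                     // 2 GiB = 2048 MiB
    double bytes = ByteUnit.KiB.toBytes(3);               // 3 KiB = 3072.0 bytes
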
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
new file mode 100644
index 0000000..d944d9d
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.NoSuchElementException;
+
+/**
+ * Provides a mechanism for constructing a {@link TransportConf} from an underlying configuration source.
+ */
+public abstract class ConfigProvider {
+ /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */
+ public abstract String get(String name);
+
+ public String get(String name, String defaultValue) {
+ try {
+ return get(name);
+ } catch (NoSuchElementException e) {
+ return defaultValue;
+ }
+ }
+
+ public int getInt(String name, int defaultValue) {
+ return Integer.parseInt(get(name, Integer.toString(defaultValue)));
+ }
+
+ public long getLong(String name, long defaultValue) {
+ return Long.parseLong(get(name, Long.toString(defaultValue)));
+ }
+
+ public double getDouble(String name, double defaultValue) {
+ return Double.parseDouble(get(name, Double.toString(defaultValue)));
+ }
+
+ public boolean getBoolean(String name, boolean defaultValue) {
+ return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue)));
+ }
+}
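
ConfigProvider is deliberately minimal: subclasses implement get(String), and the typed getters fall out of it. A hedged sketch of a custom provider (the environment-variable backing and the key name are illustrative, not from this commit):

    // Sketch only; the key name below is hypothetical.
    ConfigProvider conf = new ConfigProvider() {
      @Override
      public String get(String name) {
        String value = System.getenv(name); // illustrative backing store
        if (value == null) {
          throw new java.util.NoSuchElementException(name);
        }
        return value;
      }
    };
    boolean flag = conf.getBoolean("EXAMPLE_FLAG", false); // default used when unset
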
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java b/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java
new file mode 100644
index 0000000..6b208d9
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+/**
+ * Selector for which form of low-level IO we should use.
+ * NIO is always available, while EPOLL is only available on Linux.
+ * Callers wanting automatic selection should use EPOLL when it is available and NIO otherwise.
+ */
+public enum IOMode {
+ NIO, EPOLL
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
new file mode 100644
index 0000000..b3d8e0c
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.TimeUnit;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * General utilities available in the network package. Many of these are sourced from Spark's
+ * own Utils, made accessible within this package.
+ */
+public class JavaUtils {
+ private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class);
+
+ /**
+ * Define a default value for driver memory here since this value is referenced across the code
+ * base and nearly all files already use Utils.scala
+ */
+ public static final long DEFAULT_DRIVER_MEM_MB = 1024;
+
+ /** Closes the given object, ignoring IOExceptions. */
+ public static void closeQuietly(Closeable closeable) {
+ try {
+ if (closeable != null) {
+ closeable.close();
+ }
+ } catch (IOException e) {
+ logger.error("IOException should not have been thrown.", e);
+ }
+ }
+
+ /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
+ public static int nonNegativeHash(Object obj) {
+ if (obj == null) { return 0; }
+ int hash = obj.hashCode();
+ return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0;
+ }
+
+ /**
+ * Convert the given string to a byte buffer. The resulting buffer can be
+ * converted back to the same string through {@link #bytesToString(ByteBuffer)}.
+ */
+ public static ByteBuffer stringToBytes(String s) {
+ return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer();
+ }
+
+ /**
+ * Convert the given byte buffer to a string. The resulting string can be
+ * converted back to the same byte buffer through {@link #stringToBytes(String)}.
+ */
+ public static String bytesToString(ByteBuffer b) {
+ return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8);
+ }
+
+ /*
+ * Delete a file or directory and its contents recursively.
+ * Don't follow directories if they are symlinks.
+ * Throws an exception if deletion is unsuccessful.
+ */
+ public static void deleteRecursively(File file) throws IOException {
+ if (file == null) { return; }
+
+ if (file.isDirectory() && !isSymlink(file)) {
+ IOException savedIOException = null;
+ for (File child : listFilesSafely(file)) {
+ try {
+ deleteRecursively(child);
+ } catch (IOException e) {
+ // In case of multiple exceptions, only the last one will be thrown
+ savedIOException = e;
+ }
+ }
+ if (savedIOException != null) {
+ throw savedIOException;
+ }
+ }
+
+ boolean deleted = file.delete();
+ // Delete can also fail if the file simply did not exist.
+ if (!deleted && file.exists()) {
+ throw new IOException("Failed to delete: " + file.getAbsolutePath());
+ }
+ }
+
+ private static File[] listFilesSafely(File file) throws IOException {
+ if (file.exists()) {
+ File[] files = file.listFiles();
+ if (files == null) {
+ throw new IOException("Failed to list files for dir: " + file);
+ }
+ return files;
+ } else {
+ return new File[0];
+ }
+ }
+
+ private static boolean isSymlink(File file) throws IOException {
+ Preconditions.checkNotNull(file);
+ File fileInCanonicalDir = null;
+ if (file.getParent() == null) {
+ fileInCanonicalDir = file;
+ } else {
+ fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName());
+ }
+ return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile());
+ }
+
+ private static final ImmutableMap<String, TimeUnit> timeSuffixes =
+ ImmutableMap.<String, TimeUnit>builder()
+ .put("us", TimeUnit.MICROSECONDS)
+ .put("ms", TimeUnit.MILLISECONDS)
+ .put("s", TimeUnit.SECONDS)
+ .put("m", TimeUnit.MINUTES)
+ .put("min", TimeUnit.MINUTES)
+ .put("h", TimeUnit.HOURS)
+ .put("d", TimeUnit.DAYS)
+ .build();
+
+ private static final ImmutableMap<String, ByteUnit> byteSuffixes =
+ ImmutableMap.<String, ByteUnit>builder()
+ .put("b", ByteUnit.BYTE)
+ .put("k", ByteUnit.KiB)
+ .put("kb", ByteUnit.KiB)
+ .put("m", ByteUnit.MiB)
+ .put("mb", ByteUnit.MiB)
+ .put("g", ByteUnit.GiB)
+ .put("gb", ByteUnit.GiB)
+ .put("t", ByteUnit.TiB)
+ .put("tb", ByteUnit.TiB)
+ .put("p", ByteUnit.PiB)
+ .put("pb", ByteUnit.PiB)
+ .build();
+
+ /**
+ * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count in the given unit
+ * for internal use. If no suffix is provided, the value is interpreted directly in that unit.
+ */
+ private static long parseTimeString(String str, TimeUnit unit) {
+ String lower = str.toLowerCase().trim();
+
+ try {
+ Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower);
+ if (!m.matches()) {
+ throw new NumberFormatException("Failed to parse time string: " + str);
+ }
+
+ long val = Long.parseLong(m.group(1));
+ String suffix = m.group(2);
+
+ // Check for invalid suffixes
+ if (suffix != null && !timeSuffixes.containsKey(suffix)) {
+ throw new NumberFormatException("Invalid suffix: \"" + suffix + "\"");
+ }
+
+ // If a suffix was provided use it, otherwise fall back to the default unit
+ return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit);
+ } catch (NumberFormatException e) {
+ String timeError = "Time must be specified as seconds (s), " +
+ "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " +
+ "E.g. 50s, 100ms, or 250us.";
+
+ throw new NumberFormatException(timeError + "\n" + e.getMessage());
+ }
+ }
+
+ /**
+ * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If
+ * no suffix is provided, the passed number is assumed to be in ms.
+ */
+ public static long timeStringAsMs(String str) {
+ return parseTimeString(str, TimeUnit.MILLISECONDS);
+ }
+
+ /**
+ * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If
+ * no suffix is provided, the passed number is assumed to be in seconds.
+ */
+ public static long timeStringAsSec(String str) {
+ return parseTimeString(str, TimeUnit.SECONDS);
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a count in the given byte unit
+ * for internal use. If no suffix is provided, the value is interpreted directly in that unit.
+ */
+ private static long parseByteString(String str, ByteUnit unit) {
+ String lower = str.toLowerCase().trim();
+
+ try {
+ Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower);
+ Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower);
+
+ if (m.matches()) {
+ long val = Long.parseLong(m.group(1));
+ String suffix = m.group(2);
+
+ // Check for invalid suffixes
+ if (suffix != null && !byteSuffixes.containsKey(suffix)) {
+ throw new NumberFormatException("Invalid suffix: \"" + suffix + "\"");
+ }
+
+ // If a suffix was provided use it, otherwise fall back to the default unit
+ return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit);
+ } else if (fractionMatcher.matches()) {
+ throw new NumberFormatException("Fractional values are not supported. Input was: "
+ + fractionMatcher.group(1));
+ } else {
+ throw new NumberFormatException("Failed to parse byte string: " + str);
+ }
+
+ } catch (NumberFormatException e) {
+ String timeError = "Size must be specified as bytes (b), " +
+ "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " +
+ "E.g. 50b, 100k, or 250m.";
+
+ throw new NumberFormatException(byteError + "\n" + e.getMessage());
+ }
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for
+ * internal use.
+ *
+ * If no suffix is provided, the passed number is assumed to be in bytes.
+ */
+ public static long byteStringAsBytes(String str) {
+ return parseByteString(str, ByteUnit.BYTE);
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for
+ * internal use.
+ *
+ * If no suffix is provided, the passed number is assumed to be in kibibytes.
+ */
+ public static long byteStringAsKb(String str) {
+ return parseByteString(str, ByteUnit.KiB);
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for
+ * internal use.
+ *
+ * If no suffix is provided, the passed number is assumed to be in mebibytes.
+ */
+ public static long byteStringAsMb(String str) {
+ return parseByteString(str, ByteUnit.MiB);
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for
+ * internal use.
+ *
+ * If no suffix is provided, the passed number is assumed to be in gibibytes.
+ */
+ public static long byteStringAsGb(String str) {
+ return parseByteString(str, ByteUnit.GiB);
+ }
+
+ /**
+ * Returns a byte array with the buffer's contents, trying to avoid copying the data if
+ * possible.
+ */
+ public static byte[] bufferToArray(ByteBuffer buffer) {
+ if (buffer.hasArray() && buffer.arrayOffset() == 0 &&
+ buffer.array().length == buffer.remaining()) {
+ return buffer.array();
+ } else {
+ byte[] bytes = new byte[buffer.remaining()];
+ buffer.get(bytes);
+ return bytes;
+ }
+ }
+
+}
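
The parse helpers above accept a magnitude plus an optional suffix; a few worked conversions (illustrative sketch, results follow from the suffix tables above):

    long ms    = JavaUtils.timeStringAsMs("120s");    // 120000
    long secs  = JavaUtils.timeStringAsSec("5000ms"); // 5
    long bytes = JavaUtils.byteStringAsBytes("64k");  // 65536
    long mib   = JavaUtils.byteStringAsMb("1g");      // 1024
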
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
new file mode 100644
index 0000000..922c37a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Based on LimitedInputStream.java from Google Guava
+ *
+ * Copyright (C) 2007 The Guava Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.FilterInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Wraps a {@link InputStream}, limiting the number of bytes which can be read.
+ *
+ * This code is from Guava's 14.0 source code, because there is no compatible way to
+ * use this functionality in both a Guava 11 environment and a Guava >14 environment.
+ */
+public final class LimitedInputStream extends FilterInputStream {
+ private long left;
+ private long mark = -1;
+
+ public LimitedInputStream(InputStream in, long limit) {
+ super(in);
+ Preconditions.checkNotNull(in);
+ Preconditions.checkArgument(limit >= 0, "limit must be non-negative");
+ left = limit;
+ }
+ @Override public int available() throws IOException {
+ return (int) Math.min(in.available(), left);
+ }
+ // it's okay to mark even if mark isn't supported, since reset will fail anyway
+ @Override public synchronized void mark(int readLimit) {
+ in.mark(readLimit);
+ mark = left;
+ }
+ @Override public int read() throws IOException {
+ if (left == 0) {
+ return -1;
+ }
+ int result = in.read();
+ if (result != -1) {
+ --left;
+ }
+ return result;
+ }
+ @Override public int read(byte[] b, int off, int len) throws IOException {
+ if (left == 0) {
+ return -1;
+ }
+ len = (int) Math.min(len, left);
+ int result = in.read(b, off, len);
+ if (result != -1) {
+ left -= result;
+ }
+ return result;
+ }
+ @Override public synchronized void reset() throws IOException {
+ if (!in.markSupported()) {
+ throw new IOException("Mark not supported");
+ }
+ if (mark == -1) {
+ throw new IOException("Mark not set");
+ }
+ in.reset();
+ left = mark;
+ }
+ @Override public long skip(long n) throws IOException {
+ n = Math.min(n, left);
+ long skipped = in.skip(n);
+ left -= skipped;
+ return skipped;
+ }
+}
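
Typical use is to cap how much of an underlying stream a consumer may read, e.g. when serving a segment of a file. A minimal sketch (the path is illustrative; assumes surrounding code that handles IOException):

    java.io.InputStream limited = new LimitedInputStream(
        new java.io.FileInputStream("/tmp/segment.bin"), 1024);
    byte[] buf = new byte[4096];
    int n = limited.read(buf, 0, buf.length); // returns at most 1024 bytes in total
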
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
new file mode 100644
index 0000000..668d235
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.Map;
+import java.util.NoSuchElementException;
+
+import com.google.common.collect.Maps;
+
+/** ConfigProvider based on a Map (copied in the constructor). */
+public class MapConfigProvider extends ConfigProvider {
+ private final Map<String, String> config;
+
+ public MapConfigProvider(Map<String, String> config) {
+ this.config = Maps.newHashMap(config);
+ }
+
+ @Override
+ public String get(String name) {
+ String value = config.get(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java
new file mode 100644
index 0000000..caa7260
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.lang.reflect.Field;
+import java.util.concurrent.ThreadFactory;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollServerSocketChannel;
+import io.netty.channel.epoll.EpollSocketChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import io.netty.util.internal.PlatformDependent;
+
+/**
+ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
+ */
+public class NettyUtils {
+ /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+ public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+ return new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat(threadPoolPrefix + "-%d")
+ .build();
+ }
+
+ /** Creates a Netty EventLoopGroup based on the IOMode. */
+ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+ ThreadFactory threadFactory = createThreadFactory(threadPrefix);
+
+ switch (mode) {
+ case NIO:
+ return new NioEventLoopGroup(numThreads, threadFactory);
+ case EPOLL:
+ return new EpollEventLoopGroup(numThreads, threadFactory);
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct (client) SocketChannel class based on IOMode. */
+ public static Class<? extends Channel> getClientChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioSocketChannel.class;
+ case EPOLL:
+ return EpollSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct ServerSocketChannel class based on IOMode. */
+ public static Class<? extends ServerChannel> getServerChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioServerSocketChannel.class;
+ case EPOLL:
+ return EpollServerSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /**
+ * Creates a frame decoder where the first 8 bytes are the length of the frame.
+ * This is used before all decoders.
+ */
+ public static TransportFrameDecoder createFrameDecoder() {
+ return new TransportFrameDecoder();
+ }
+
+ /** Returns the remote address on the channel or "<unknown remote>" if none exists. */
+ public static String getRemoteAddress(Channel channel) {
+ if (channel != null && channel.remoteAddress() != null) {
+ return channel.remoteAddress().toString();
+ }
+ return "<unknown remote>";
+ }
+
+ /**
+ * Creates a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches
+ * are disabled for TransportClients because the ByteBufs are allocated by the event loop thread,
+ * but released by the executor thread rather than the event loop thread. Those thread-local
+ * caches actually delay the recycling of buffers, leading to larger memory usage.
+ */
+ public static PooledByteBufAllocator createPooledByteBufAllocator(
+ boolean allowDirectBufs,
+ boolean allowCache,
+ int numCores) {
+ if (numCores == 0) {
+ numCores = Runtime.getRuntime().availableProcessors();
+ }
+ return new PooledByteBufAllocator(
+ allowDirectBufs && PlatformDependent.directBufferPreferred(),
+ Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores),
+ Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0),
+ getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+ getPrivateStaticField("DEFAULT_MAX_ORDER"),
+ allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0,
+ allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0,
+ allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0
+ );
+ }
+
+ /** Used to get defaults from Netty's private static fields. */
+ private static int getPrivateStaticField(String name) {
+ try {
+ Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.getInt(null);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
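
These helpers are consumed when bootstrapping Netty clients and servers elsewhere in this module. A hedged sketch of how they compose for a client (thread count and names are arbitrary):

    IOMode ioMode = IOMode.valueOf("NIO");
    io.netty.bootstrap.Bootstrap bootstrap = new io.netty.bootstrap.Bootstrap()
      .group(NettyUtils.createEventLoop(ioMode, 8, "example-client"))
      .channel(NettyUtils.getClientChannelClass(ioMode))
      .option(io.netty.channel.ChannelOption.ALLOCATOR,
          NettyUtils.createPooledByteBufAllocator(true, false /* allowCache */, 8));
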
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
new file mode 100644
index 0000000..5f20b70
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.NoSuchElementException;
+
+/** Uses System properties to obtain config values. */
+public class SystemPropertyConfigProvider extends ConfigProvider {
+ @Override
+ public String get(String name) {
+ String value = System.getProperty(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
new file mode 100644
index 0000000..9f030da
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import com.google.common.primitives.Ints;
+
+/**
+ * A central location that tracks all the settings we expose to users.
+ */
+public class TransportConf {
+
+ private final String SPARK_NETWORK_IO_MODE_KEY;
+ private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY;
+ private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY;
+ private final String SPARK_NETWORK_IO_BACKLOG_KEY;
+ private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY;
+ private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY;
+ private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY;
+ private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY;
+ private final String SPARK_NETWORK_IO_SENDBUFFER_KEY;
+ private final String SPARK_NETWORK_SASL_TIMEOUT_KEY;
+ private final String SPARK_NETWORK_IO_MAXRETRIES_KEY;
+ private final String SPARK_NETWORK_IO_RETRYWAIT_KEY;
+ private final String SPARK_NETWORK_IO_LAZYFD_KEY;
+
+ private final ConfigProvider conf;
+
+ private final String module;
+
+ public TransportConf(String module, ConfigProvider conf) {
+ this.module = module;
+ this.conf = conf;
+ SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode");
+ SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs");
+ SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout");
+ SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog");
+ SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer");
+ SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads");
+ SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads");
+ SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer");
+ SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer");
+ SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout");
+ SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries");
+ SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait");
+ SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD");
+ }
+
+ private String getConfKey(String suffix) {
+ return "spark." + module + "." + suffix;
+ }
+
+ /** IO mode: nio or epoll */
+ public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); }
+
+ /** If true, we will prefer allocating off-heap byte buffers within Netty. */
+ public boolean preferDirectBufs() {
+ return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true);
+ }
+
+ /** Connection timeout in milliseconds. Defaults to 120 seconds. */
+ public int connectionTimeoutMs() {
+ long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec(
+ conf.get("spark.network.timeout", "120s"));
+ long defaultTimeoutMs = JavaUtils.timeStringAsSec(
+ conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000;
+ return (int) defaultTimeoutMs;
+ }
+
+ /** Number of concurrent connections between two nodes for fetching data. */
+ public int numConnectionsPerPeer() {
+ return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1);
+ }
+
+ /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */
+ public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); }
+
+ /** Number of threads used in the server thread pool. Defaults to 0, which means 2x #cores. */
+ public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); }
+
+ /** Number of threads used in the client thread pool. Defaults to 0, which means 2x #cores. */
+ public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); }
+
+ /**
+ * Receive buffer size (SO_RCVBUF).
+ * Note: the optimal size for receive buffer and send buffer should be
+ * latency * network_bandwidth.
+ * Assuming latency = 1ms, network_bandwidth = 10Gbps
+ * buffer size should be ~ 1.25MB
+ */
+ public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); }
+
+ /** Send buffer size (SO_SNDBUF). */
+ public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); }
+
+ /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
+ public int saslRTTimeoutMs() {
+ return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000;
+ }
+
+ /**
+ * Max number of times we will retry after IO exceptions (such as connection timeouts) per request.
+ * If set to 0, we will not do any retries.
+ */
+ public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); }
+
+ /**
+ * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+ * Only relevant if maxIORetries > 0.
+ */
+ public int ioRetryWaitTimeMs() {
+ return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000;
+ }
+
+ /**
+ * Minimum size of a block at which we start memory mapping rather than reading through
+ * normal IO operations. This prevents Spark from memory mapping very small blocks. In general,
+ * memory mapping has high overhead for blocks close to or below the page size of the OS.
+ */
+ public int memoryMapBytes() {
+ return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+ conf.get("spark.storage.memoryMapThreshold", "2m")));
+ }
+
+ /**
+ * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are
+ * created only when data is going to be transferred. This can reduce the number of open files.
+ */
+ public boolean lazyFileDescriptor() {
+ return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true);
+ }
+
+ /**
+ * Maximum number of retries when binding to a port before giving up.
+ */
+ public int portMaxRetries() {
+ return conf.getInt("spark.port.maxRetries", 16);
+ }
+
+ /**
+ * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
+ */
+ public int maxSaslEncryptedBlockSize() {
+ return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+ conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ }
+
+ /**
+ * Whether the server should enforce encryption on SASL-authenticated connections.
+ */
+ public boolean saslServerAlwaysEncrypt() {
+ return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
+ }
+
+}
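
Putting TransportConf together with the providers above (illustrative sketch; the override values are arbitrary):

    java.util.Map<String, String> overrides = com.google.common.collect.ImmutableMap.of(
        "spark.shuffle.io.mode", "epoll",
        "spark.shuffle.io.connectionTimeout", "60s");
    TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(overrides));
    conf.ioMode();                // "EPOLL" (upper-cased by the getter)
    conf.connectionTimeoutMs();   // 60000
    conf.numConnectionsPerPeer(); // 1, the default, since that key is unset
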
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
new file mode 100644
index 0000000..a466c72
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.Iterator;
+import java.util.LinkedList;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+
+/**
+ * A customized frame decoder that allows intercepting raw data.
+ * <p>
+ * This behaves like Netty's frame decoder (with hardcoded parameters that match this library's
+ * needs), except it allows an interceptor to be installed to read data directly before it's
+ * framed.
+ * <p>
+ * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's
+ * decoded, instead of building as many frames as the current buffer allows and dispatching
+ * all of them. This allows a child handler to install an interceptor if needed.
+ * <p>
+ * If an interceptor is installed, framing stops, and data is instead fed directly to the
+ * interceptor. When the interceptor indicates that it doesn't need to read any more data,
+ * framing resumes. Interceptors should not hold references to the data buffers provided
+ * to their handle() method.
+ */
+public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
+
+ public static final String HANDLER_NAME = "frameDecoder";
+ private static final int LENGTH_SIZE = 8;
+ private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+ private static final int UNKNOWN_FRAME_SIZE = -1;
+
+ private final LinkedList<ByteBuf> buffers = new LinkedList<>();
+ private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE);
+
+ private long totalSize = 0;
+ private long nextFrameSize = UNKNOWN_FRAME_SIZE;
+ private volatile Interceptor interceptor;
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
+ ByteBuf in = (ByteBuf) data;
+ buffers.add(in);
+ totalSize += in.readableBytes();
+
+ while (!buffers.isEmpty()) {
+ // First, feed the interceptor, and if it's still active, try again.
+ if (interceptor != null) {
+ ByteBuf first = buffers.getFirst();
+ int available = first.readableBytes();
+ if (feedInterceptor(first)) {
+ assert !first.isReadable() : "Interceptor still active but buffer has data.";
+ }
+
+ int read = available - first.readableBytes();
+ if (read == available) {
+ buffers.removeFirst().release();
+ }
+ totalSize -= read;
+ } else {
+ // Interceptor is not active, so try to decode one frame.
+ ByteBuf frame = decodeNext();
+ if (frame == null) {
+ break;
+ }
+ ctx.fireChannelRead(frame);
+ }
+ }
+ }
+
+ private long decodeFrameSize() {
+ if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) {
+ return nextFrameSize;
+ }
+
+ // We know there's enough data. If the first buffer contains all the data, great. Otherwise,
+ // hold the bytes for the frame length in a dedicated buffer until we have enough data to read
+ // the frame size. Normally, it should be rare to need more than one buffer to read the frame
+ // size.
+ ByteBuf first = buffers.getFirst();
+ if (first.readableBytes() >= LENGTH_SIZE) {
+ nextFrameSize = first.readLong() - LENGTH_SIZE;
+ totalSize -= LENGTH_SIZE;
+ if (!first.isReadable()) {
+ buffers.removeFirst().release();
+ }
+ return nextFrameSize;
+ }
+
+ while (frameLenBuf.readableBytes() < LENGTH_SIZE) {
+ ByteBuf next = buffers.getFirst();
+ int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes());
+ frameLenBuf.writeBytes(next, toRead);
+ if (!next.isReadable()) {
+ buffers.removeFirst().release();
+ }
+ }
+
+ nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE;
+ totalSize -= LENGTH_SIZE;
+ frameLenBuf.clear();
+ return nextFrameSize;
+ }
+
+ private ByteBuf decodeNext() throws Exception {
+ long frameSize = decodeFrameSize();
+ if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
+ return null;
+ }
+
+ // Reset size for next frame.
+ nextFrameSize = UNKNOWN_FRAME_SIZE;
+
+ Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
+ Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
+
+ // If the first buffer holds the entire frame, return it.
+ int remaining = (int) frameSize;
+ if (buffers.getFirst().readableBytes() >= remaining) {
+ return nextBufferForFrame(remaining);
+ }
+
+ // Otherwise, create a composite buffer.
+ CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer();
+ while (remaining > 0) {
+ ByteBuf next = nextBufferForFrame(remaining);
+ remaining -= next.readableBytes();
+ frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes());
+ }
+ assert remaining == 0;
+ return frame;
+ }
+
+ /**
+ * Takes the first buffer in the internal list, and either adjusts it to fit the frame
+ * (by taking a slice out of it) or removes it from the internal list.
+ */
+ private ByteBuf nextBufferForFrame(int bytesToRead) {
+ ByteBuf buf = buffers.getFirst();
+ ByteBuf frame;
+
+ if (buf.readableBytes() > bytesToRead) {
+ frame = buf.retain().readSlice(bytesToRead);
+ totalSize -= bytesToRead;
+ } else {
+ frame = buf;
+ buffers.removeFirst();
+ totalSize -= frame.readableBytes();
+ }
+
+ return frame;
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ for (ByteBuf b : buffers) {
+ b.release();
+ }
+ if (interceptor != null) {
+ interceptor.channelInactive();
+ }
+ frameLenBuf.release();
+ super.channelInactive(ctx);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ if (interceptor != null) {
+ interceptor.exceptionCaught(cause);
+ }
+ super.exceptionCaught(ctx, cause);
+ }
+
+ public void setInterceptor(Interceptor interceptor) {
+ Preconditions.checkState(this.interceptor == null, "Already have an interceptor.");
+ this.interceptor = interceptor;
+ }
+
+ /**
+ * @return Whether the interceptor is still active after processing the data.
+ */
+ private boolean feedInterceptor(ByteBuf buf) throws Exception {
+ if (interceptor != null && !interceptor.handle(buf)) {
+ interceptor = null;
+ }
+ return interceptor != null;
+ }
+
+ public interface Interceptor {
+
+ /**
+ * Handles data received from the remote end.
+ *
+ * @param data Buffer containing data.
+ * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor.
+ */
+ boolean handle(ByteBuf data) throws Exception;
+
+ /** Called if an exception is thrown in the channel pipeline. */
+ void exceptionCaught(Throwable cause) throws Exception;
+
+ /** Called if the channel is closed and the interceptor is still installed. */
+ void channelInactive() throws Exception;
+
+ }
+
+}
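
The wire format the decoder assumes is an 8-byte length prefix that counts itself, followed by the frame body; decodeFrameSize() reads the long and subtracts LENGTH_SIZE. A small sketch of one conforming frame (illustrative, not from the commit):

    byte[] body = "hello".getBytes(com.google.common.base.Charsets.UTF_8);
    io.netty.buffer.ByteBuf frame = io.netty.buffer.Unpooled.buffer();
    frame.writeLong(8 + body.length); // LENGTH_SIZE + payload; the length counts itself
    frame.writeBytes(body);           // decodeFrameSize() then yields body.length
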
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
new file mode 100644
index 0000000..70c849d
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import com.google.common.io.Closeables;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ChunkFetchIntegrationSuite {
+ static final long STREAM_ID = 1;
+ static final int BUFFER_CHUNK_INDEX = 0;
+ static final int FILE_CHUNK_INDEX = 1;
+
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static StreamManager streamManager;
+ static File testFile;
+
+ static ManagedBuffer bufferChunk;
+ static ManagedBuffer fileChunk;
+
+ private TransportConf transportConf;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ int bufSize = 100000;
+ final ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ bufferChunk = new NioManagedBuffer(buf);
+
+ testFile = File.createTempFile("shuffle-test-file", "txt");
+ testFile.deleteOnExit();
+ RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
+ boolean shouldSuppressIOException = true;
+ try {
+ byte[] fileContent = new byte[1024];
+ new Random().nextBytes(fileContent);
+ fp.write(fileContent);
+ shouldSuppressIOException = false;
+ } finally {
+ Closeables.close(fp, shouldSuppressIOException);
+ }
+
+ final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+
+ streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ assertEquals(STREAM_ID, streamId);
+ if (chunkIndex == BUFFER_CHUNK_INDEX) {
+ return new NioManagedBuffer(buf);
+ } else if (chunkIndex == FILE_CHUNK_INDEX) {
+ return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+ } else {
+ throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
+ }
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ bufferChunk.release();
+ server.close();
+ clientFactory.close();
+ testFile.delete();
+ }
+
+ class FetchResult {
+ public Set<Integer> successChunks;
+ public Set<Integer> failedChunks;
+ public List<ManagedBuffer> buffers;
+
+ public void releaseBuffers() {
+ for (ManagedBuffer buffer : buffers) {
+ buffer.release();
+ }
+ }
+ }
+
+ private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
+
+ final FetchResult res = new FetchResult();
+ res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ buffer.retain();
+ res.successChunks.add(chunkIndex);
+ res.buffers.add(buffer);
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ res.failedChunks.add(chunkIndex);
+ sem.release();
+ }
+ };
+
+ for (int chunkIndex : chunkIndices) {
+ client.fetchChunk(STREAM_ID, chunkIndex, callback);
+ }
+ if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void fetchBufferChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchFileChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchNonExistentChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(12345));
+ assertTrue(res.successChunks.isEmpty());
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertTrue(res.buffers.isEmpty());
+ }
+
+ @Test
+ public void fetchBothChunks() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchChunkAndNonExistent() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ private void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1)
+ throws Exception {
+ assertEquals(list0.size(), list1.size());
+ for (int i = 0; i < list0.size(); i ++) {
+ assertBuffersEqual(list0.get(i), list1.get(i));
+ }
+ }
+
+ private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+ ByteBuffer nio0 = buffer0.nioByteBuffer();
+ ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+ int len = nio0.remaining();
+ assertEquals(nio0.remaining(), nio1.remaining());
+ for (int i = 0; i < len; i ++) {
+ assertEquals(nio0.get(), nio1.get());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
new file mode 100644
index 0000000..6c8dd74
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.List;
+
+import com.google.common.primitives.Ints;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.FileRegion;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.protocol.OneWayMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.NettyUtils;
+
+public class ProtocolSuite {
+ private void testServerToClient(Message msg) {
+ EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
+ new MessageEncoder());
+ serverChannel.writeOutbound(msg);
+
+ EmbeddedChannel clientChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
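+ // Hand every encoded frame from the server channel's outbound queue to the client pipeline.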
+ while (!serverChannel.outboundMessages().isEmpty()) {
+ clientChannel.writeInbound(serverChannel.readOutbound());
+ }
+
+ assertEquals(1, clientChannel.inboundMessages().size());
+ assertEquals(msg, clientChannel.readInbound());
+ }
+
+ private void testClientToServer(Message msg) {
+ EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
+ new MessageEncoder());
+ clientChannel.writeOutbound(msg);
+
+ EmbeddedChannel serverChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!clientChannel.outboundMessages().isEmpty()) {
+ serverChannel.writeInbound(clientChannel.readOutbound());
+ }
+
+ assertEquals(1, serverChannel.inboundMessages().size());
+ assertEquals(msg, serverChannel.readInbound());
+ }
+
+ @Test
+ public void requests() {
+ testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
+ testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0)));
+ testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10)));
+ testClientToServer(new StreamRequest("abcde"));
+ testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
+ }
+
+ @Test
+ public void responses() {
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
+ testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
+ testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100)));
+ testServerToClient(new RpcFailure(0, "this is an error"));
+ testServerToClient(new RpcFailure(0, ""));
+ // Note: the buffer size must be 0, because StreamResponse's buffer is written to the
+ // channel separately and cannot be round-tripped by this test.
+ testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
+ testServerToClient(new StreamFailure("anId", "this is an error"));
+ }
+
+ /**
+ * Handler that transforms a FileRegion into a byte buffer. EmbeddedChannel transfers messages
+ * rather than raw bytes, so this encoder is needed for the frame decoder on the receiving side
+ * to see the bytes that MessageWithHeader actually contains.
+ */
+ private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
+
+ @Override
+ public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
+ throws Exception {
+
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
+ while (in.transfered() < in.count()) {
+ in.transferTo(channel, in.transfered());
+ }
+ out.add(Unpooled.wrappedBuffer(channel.getData()));
+ }
+
+ }
+
+}
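
Both helpers above lean on Netty's EmbeddedChannel, which shuttles messages (not raw
bytes) through a pipeline entirely in-process; that is also why the FileRegionEncoder
shim is needed. A minimal encode/decode round trip using Netty's stock string codecs,
independent of Spark's protocol classes and offered only as a sketch:

  import io.netty.channel.embedded.EmbeddedChannel;
  import io.netty.handler.codec.string.StringDecoder;
  import io.netty.handler.codec.string.StringEncoder;

  public class EmbeddedRoundTripSketch {
    public static void main(String[] args) {
      // The "server" channel encodes a String into a ByteBuf...
      EmbeddedChannel out = new EmbeddedChannel(new StringEncoder());
      out.writeOutbound("ping");

      // ...and the "client" channel decodes that ByteBuf back into a String,
      // the same hand-off the while loops above perform frame by frame.
      EmbeddedChannel in = new EmbeddedChannel(new StringDecoder());
      while (!out.outboundMessages().isEmpty()) {
        in.writeInbound(out.readOutbound());
      }
      String decoded = (String) in.readInbound();
      if (!"ping".equals(decoded)) {
        throw new AssertionError("Round trip failed: " + decoded);
      }
    }
  }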
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
new file mode 100644
index 0000000..f9b5bf9
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.junit.*;
+import static org.junit.Assert.*;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Suite which ensures that requests that receive no response within the network timeout period
+ * are failed and the connection is closed.
+ *
+ * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests,
+ * to ensure stability in different test environments.
+ */
+public class RequestTimeoutIntegrationSuite {
+
+ private TransportServer server;
+ private TransportClientFactory clientFactory;
+
+ private StreamManager defaultManager;
+ private TransportConf conf;
+
+ // A timeout large enough that it should never be reached; it only bounds how long a faulty test can hang.
+ private final int FOREVER = 60 * 1000;
+
+ @Before
+ public void setUp() throws Exception {
+ Map<String, String> configMap = Maps.newHashMap();
+ configMap.put("spark.shuffle.io.connectionTimeout", "2s");
+ conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
+
+ defaultManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ @After
+ public void tearDown() {
+ if (server != null) {
+ server.close();
+ }
+ if (clientFactory != null) {
+ clientFactory.close();
+ }
+ }
+
+ // Basic case: the first request completes quickly, and the second waits longer than the network timeout.
+ @Test
+ public void timeoutInactiveRequests() throws Exception {
+ final Semaphore semaphore = new Semaphore(1);
+ final int responseSize = 16;
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ try {
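+ // Blocks while no permit is available, stalling the reply past the 2-second timeout.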
+ semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+ callback.onSuccess(ByteBuffer.allocate(responseSize));
+ } catch (InterruptedException e) {
+ // do nothing
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return defaultManager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ // First completes quickly (semaphore starts at 1).
+ TestCallback callback0 = new TestCallback();
+ synchronized (callback0) {
+ client.sendRpc(ByteBuffer.allocate(0), callback0);
+ callback0.wait(FOREVER);
+ assertEquals(responseSize, callback0.successLength);
+ }
+
+ // Second times out after 2 seconds, with slack. Must be IOException.
+ TestCallback callback1 = new TestCallback();
+ synchronized (callback1) {
+ client.sendRpc(ByteBuffer.allocate(0), callback1);
+ callback1.wait(4 * 1000);
+ assertNotNull(callback1.failure);
+ assertTrue(callback1.failure instanceof IOException);
+ }
+ semaphore.release();
+ }
+
+ // A timeout will cause the connection to be closed, invalidating the current TransportClient.
+ // Requesting a client from the factory must then produce a new, valid one.
+ @Test
+ public void timeoutCleanlyClosesClient() throws Exception {
+ final Semaphore semaphore = new Semaphore(0);
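+ // No permits initially: every request stalls server-side until release(2) is called below.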
+ final int responseSize = 16;
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ try {
+ semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+ callback.onSuccess(ByteBuffer.allocate(responseSize));
+ } catch (InterruptedException e) {
+ // do nothing
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return defaultManager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+
+ // First request should eventually fail.
+ TransportClient client0 =
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ TestCallback callback0 = new TestCallback();
+ synchronized (callback0) {
+ client0.sendRpc(ByteBuffer.allocate(0), callback0);
+ callback0.wait(FOREVER);
+ assertTrue(callback0.failure instanceof IOException);
+ assertFalse(client0.isActive());
+ }
+
+ // Increment the semaphore and the second request should succeed quickly.
+ semaphore.release(2);
+ TransportClient client1 =
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ TestCallback callback1 = new TestCallback();
+ synchronized (callback1) {
+ client1.sendRpc(ByteBuffer.allocate(0), callback1);
+ callback1.wait(FOREVER);
+ assertEquals(responseSize, callback1.successLength);
+ assertNull(callback1.failure);
+ }
+ }
+
+ // The timeout is measured from the LAST request sent, so later traffic keeps earlier
+ // outstanding requests alive; surprising, but that is the current behavior.
+ // This test also makes sure the timeout works for Fetch requests as well as RPCs.
+ @Test
+ public void furtherRequestsDelay() throws Exception {
+ final byte[] response = new byte[16];
+ final StreamManager manager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS);
+ return new NioManagedBuffer(ByteBuffer.wrap(response));
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return manager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ // Send one request, which will eventually fail.
+ TestCallback callback0 = new TestCallback();
+ client.fetchChunk(0, 0, callback0);
+ Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
+
+ // Send a second request before the first has failed.
+ TestCallback callback1 = new TestCallback();
+ client.fetchChunk(0, 1, callback1);
+ Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
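+ // About 2.4s have elapsed since the first fetch, but the 2s timeout clock restarted with the second.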
+
+ synchronized (callback0) {
+ // not complete yet, but should complete soon
+ assertEquals(-1, callback0.successLength);
+ assertNull(callback0.failure);
+ callback0.wait(2 * 1000);
+ assertTrue(callback0.failure instanceof IOException);
+ }
+
+ synchronized (callback1) {
+ // failed at the same time as the previous request
+ assertTrue(callback1.failure instanceof IOException);
+ }
+ }
+
+ /**
+ * Callback which sets 'success' or 'failure' on completion.
+ * Additionally notifies all waiters on this callback when invoked.
+ */
+ class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
+
+ int successLength = -1;
+ Throwable failure;
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ synchronized (this) {
+ successLength = response.remaining();
+ this.notifyAll();
+ }
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ synchronized (this) {
+ failure = e;
+ this.notifyAll();
+ }
+ }
+
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ synchronized (this) {
+ try {
+ successLength = buffer.nioByteBuffer().remaining();
+ } catch (IOException e) {
+ // nioByteBuffer() should not fail for these in-memory buffers; record the
+ // error anyway so that waiters are still notified instead of timing out.
+ failure = e;
+ } finally {
+ this.notifyAll();
+ }
+ }
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ synchronized (this) {
+ failure = e;
+ this.notifyAll();
+ }
+ }
+ }
+}
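
TestCallback coordinates with bare wait/notifyAll, which is why every assertion above
sits inside a synchronized block on the callback. A CountDownLatch gives the same
rendezvous with less ceremony; a sketch of that alternative (hypothetical, not part of
this patch):

  import java.util.concurrent.CountDownLatch;
  import java.util.concurrent.TimeUnit;

  // Hypothetical latch-based equivalent of TestCallback: one latch per expected response.
  class LatchCallback {
    final CountDownLatch done = new CountDownLatch(1);
    volatile int successLength = -1;
    volatile Throwable failure;

    void onSuccess(int length) { successLength = length; done.countDown(); }
    void onFailure(Throwable e) { failure = e; done.countDown(); }

    // Returns false on timeout, mirroring the timed callback.wait(...) calls above.
    boolean await(long millis) throws InterruptedException {
      return done.await(millis, TimeUnit.MILLISECONDS);
    }
  }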
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
new file mode 100644
index 0000000..9e9be98
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Sets;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class RpcIntegrationSuite {
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static RpcHandler rpcHandler;
+ static List<String> oneWayMsgs;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ rpcHandler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ String msg = JavaUtils.bytesToString(message);
+ String[] parts = msg.split("/");
+ if (parts[0].equals("hello")) {
+ callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
+ } else if (parts[0].equals("return error")) {
+ callback.onFailure(new RuntimeException("Returned: " + parts[1]));
+ } else if (parts[0].equals("throw error")) {
+ throw new RuntimeException("Thrown: " + parts[1]);
+ }
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message) {
+ oneWayMsgs.add(JavaUtils.bytesToString(message));
+ }
+
+ @Override
+ public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
+ };
+ TransportContext context = new TransportContext(conf, rpcHandler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ oneWayMsgs = new ArrayList<>();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ }
+
+ class RpcResult {
+ public Set<String> successMessages;
+ public Set<String> errorMessages;
+ }
+
+ private RpcResult sendRPC(String... commands) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
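+ // Each callback releases one permit; tryAcquire below waits until every command is answered.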
+
+ final RpcResult res = new RpcResult();
+ res.successMessages = Collections.synchronizedSet(new HashSet<String>());
+ res.errorMessages = Collections.synchronizedSet(new HashSet<String>());
+
+ RpcResponseCallback callback = new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer message) {
+ String response = JavaUtils.bytesToString(message);
+ res.successMessages.add(response);
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ res.errorMessages.add(e.getMessage());
+ sem.release();
+ }
+ };
+
+ for (String command : commands) {
+ client.sendRpc(JavaUtils.stringToBytes(command), callback);
+ }
+
+ if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void singleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void doubleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron", "hello/Reynold");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void returnErrorRPC() throws Exception {
+ RpcResult res = sendRPC("return error/OK");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK"));
+ }
+
+ @Test
+ public void throwErrorRPC() throws Exception {
+ RpcResult res = sendRPC("throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
+ }
+
+ @Test
+ public void doubleTrouble() throws Exception {
+ RpcResult res = sendRPC("return error/OK", "throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
+ }
+
+ @Test
+ public void sendSuccessAndFailure() throws Exception {
+ RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!"));
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
+ }
+
+ @Test
+ public void sendOneWayMessage() throws Exception {
+ final String message = "no reply";
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ client.send(JavaUtils.stringToBytes(message));
+ assertEquals(0, client.getHandler().numOutstandingRequests());
+
+ // Make sure the message arrives.
+ long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+ while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) {
+ TimeUnit.MILLISECONDS.sleep(10);
+ }
+
+ assertEquals(1, oneWayMsgs.size());
+ assertEquals(message, oneWayMsgs.get(0));
+ } finally {
+ client.close();
+ }
+ }
+
+ private void assertErrorsContain(Set<String> errors, Set<String> contains) {
+ assertEquals(contains.size(), errors.size());
+
+ Set<String> remainingErrors = Sets.newHashSet(errors);
+ for (String contain : contains) {
+ Iterator<String> it = remainingErrors.iterator();
+ boolean foundMatch = false;
+ while (it.hasNext()) {
+ if (it.next().contains(contain)) {
+ it.remove();
+ foundMatch = true;
+ break;
+ }
+ }
+ assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch);
+ }
+
+ assertTrue(remainingErrors.isEmpty());
+ }
+}
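
Each RPC body in this suite is a UTF-8 string of the form "command/payload",
dispatched on the command prefix in the handler's receive(). For completeness, the
"hello" round trip can also be written synchronously; this sketch assumes
TransportClient exposes a blocking sendRpcSync(ByteBuffer, long), so treat it as
illustrative rather than part of the patch:

  // Synchronous variant of singleRPC(), under the sendRpcSync assumption above.
  TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
  try {
    ByteBuffer reply = client.sendRpcSync(JavaUtils.stringToBytes("hello/Kay"), 5000);
    assertEquals("Hello, Kay!", JavaUtils.bytesToString(reply));
  } finally {
    client.close();
  }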