You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@zookeeper.apache.org by an...@apache.org on 2018/11/27 09:02:32 UTC
[2/2] zookeeper git commit: ZOOKEEPER-3172: Quorum TLS - fix port unification to allow rolling upgrades
ZOOKEEPER-3172: Quorum TLS - fix port unification to allow rolling upgrades
Fix numerous problems with UnifiedServerSocket, such as hanging the accept() thread when the client doesn't send any data, or crashing if fewer than 5 bytes are read from the socket in the initial read.
Re-enable the "portUnification" config option.
## Fixed networking issues/bugs in UnifiedServerSocket
- don't crash the `accept()` thread if the client closes the connection without sending any data
- don't corrupt the connection if the client sends fewer than 5 bytes for the initial read
- delay the detection of TLS vs. plaintext mode until a socket stream is read from or written to. This prevents the `accept()` thread from getting blocked on a `read()` operation from the newly connected socket.
- after prepending 5 bytes to a `PrependableSocket`, a read of more than 5 bytes used to return only the 5 prepended bytes, even when more data was available on the socket. This is fixed.
Author: Ilya Maykov <il...@fb.com>
Reviewers: andor@apache.org
Closes #679 from ivmaykov/ZOOKEEPER-3172
Project: http://git-wip-us.apache.org/repos/asf/zookeeper/repo
Commit: http://git-wip-us.apache.org/repos/asf/zookeeper/commit/64104eae
Tree: http://git-wip-us.apache.org/repos/asf/zookeeper/tree/64104eae
Diff: http://git-wip-us.apache.org/repos/asf/zookeeper/diff/64104eae
Branch: refs/heads/master
Commit: 64104eaeaa6508f052edfd39c24243a8e26039dc
Parents: 91c6cb2
Author: Ilya Maykov <il...@fb.com>
Authored: Tue Nov 27 10:02:24 2018 +0100
Committer: Andor Molnar <an...@apache.org>
Committed: Tue Nov 27 10:02:24 2018 +0100
----------------------------------------------------------------------
.../org/apache/zookeeper/common/X509Util.java | 55 +-
.../org/apache/zookeeper/common/ZKConfig.java | 2 +
.../apache/zookeeper/server/quorum/Leader.java | 29 +-
.../apache/zookeeper/server/quorum/Learner.java | 9 +-
.../server/quorum/PrependableSocket.java | 29 +-
.../server/quorum/QuorumCnxManager.java | 34 +-
.../zookeeper/server/quorum/QuorumPeer.java | 8 +
.../server/quorum/QuorumPeerConfig.java | 5 +-
.../server/quorum/UnifiedServerSocket.java | 738 ++++++++++++++++++-
.../apache/zookeeper/common/X509UtilTest.java | 28 +
.../zookeeper/server/quorum/QuorumSSLTest.java | 2 -
.../UnifiedServerSocketModeDetectionTest.java | 404 ++++++++++
.../server/quorum/UnifiedServerSocketTest.java | 608 ++++++++++++---
13 files changed, 1794 insertions(+), 157 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
index 5b97ac6..e3625a5 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
@@ -18,6 +18,7 @@
package org.apache.zookeeper.common;
+import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.Socket;
import java.security.GeneralSecurityException;
@@ -74,6 +75,8 @@ public abstract class X509Util {
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256"
};
+ public static final int DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS = 5000;
+
private String sslProtocolProperty = getConfigPrefix() + "protocol";
private String cipherSuitesProperty = getConfigPrefix() + "ciphersuites";
private String sslKeystoreLocationProperty = getConfigPrefix() + "keyStore.location";
@@ -85,6 +88,7 @@ public abstract class X509Util {
private String sslHostnameVerificationEnabledProperty = getConfigPrefix() + "hostnameVerification";
private String sslCrlEnabledProperty = getConfigPrefix() + "crl";
private String sslOcspEnabledProperty = getConfigPrefix() + "ocsp";
+ private String sslHandshakeDetectionTimeoutMillisProperty = getConfigPrefix() + "handshakeDetectionTimeoutMillis";
private String[] cipherSuites;
@@ -146,6 +150,16 @@ public abstract class X509Util {
return sslOcspEnabledProperty;
}
+ /**
+ * Returns the config property key that controls the amount of time, in milliseconds, that the first
+ * UnifiedServerSocket read operation will block for when trying to detect the client mode (TLS or PLAINTEXT).
+ *
+ * @return the config property key.
+ */
+ public String getSslHandshakeDetectionTimeoutMillisProperty() {
+ return sslHandshakeDetectionTimeoutMillisProperty;
+ }
+
public SSLContext getDefaultSSLContext() throws X509Exception.SSLContextException {
SSLContext result = defaultSSLContext.get();
if (result == null) {
@@ -168,6 +182,31 @@ public abstract class X509Util {
return createSSLContext(config);
}
+ /**
+ * Returns the max amount of time, in milliseconds, that the first UnifiedServerSocket read() operation should
+ * block for when trying to detect the client mode (TLS or PLAINTEXT).
+ * Read from the system property named by {@link #getSslHandshakeDetectionTimeoutMillisProperty()}; defaults to {@link X509Util#DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS} if unset or < 1.
+ * NOTE(review): a non-integer property value makes Integer.parseInt throw NumberFormatException (no fallback to the default).
+ * @return the handshake detection timeout, in milliseconds.
+ */
+ public int getSslHandshakeTimeoutMillis() {
+ String propertyString = System.getProperty(getSslHandshakeDetectionTimeoutMillisProperty());
+ int result;
+ if (propertyString == null) {
+ result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+ } else {
+ result = Integer.parseInt(propertyString);
+ if (result < 1) {
+ // Values < 1 are rejected (a timeout of 0 is infinite, and an infinite timeout can permanently lock up an
+ // accept() thread); fall back to the default.
+ LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result +
+ ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS);
+ result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+ }
+ }
+ return result;
+ }
+
public SSLContext createSSLContext(ZKConfig config) throws SSLContextException {
KeyManager[] keyManagers = null;
TrustManager[] trustManagers = null;
@@ -350,14 +389,22 @@ public abstract class X509Util {
public SSLSocket createSSLSocket() throws X509Exception, IOException {
SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket();
configureSSLSocket(sslSocket);
-
+ sslSocket.setUseClientMode(true);
return sslSocket;
}
- public SSLSocket createSSLSocket(Socket socket) throws X509Exception, IOException {
- SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(socket, null, socket.getPort(), true);
+ public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws X509Exception, IOException {
+ SSLSocket sslSocket;
+ if (pushbackBytes != null && pushbackBytes.length > 0) {
+ sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(
+ socket, new ByteArrayInputStream(pushbackBytes), true);
+ } else {
+ sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(
+ socket, null, socket.getPort(), true);
+ }
configureSSLSocket(sslSocket);
-
+ sslSocket.setUseClientMode(false);
+ sslSocket.setNeedClientAuth(true);
return sslSocket;
}
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
index 01bac69..effc0d5 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
@@ -130,6 +130,8 @@ public class ZKConfig {
System.getProperty(x509Util.getSslCrlEnabledProperty()));
properties.put(x509Util.getSslOcspEnabledProperty(),
System.getProperty(x509Util.getSslOcspEnabledProperty()));
+ properties.put(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(),
+ System.getProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty()));
}
/**
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java
index 9270548..0a892b1 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java
@@ -42,7 +42,6 @@ import java.util.concurrent.ConcurrentMap;
import javax.security.sasl.SaslException;
import org.apache.zookeeper.ZooDefs.OpCode;
-import org.apache.zookeeper.common.QuorumX509Util;
import org.apache.zookeeper.common.Time;
import org.apache.zookeeper.common.X509Exception;
import org.apache.zookeeper.server.FinalRequestProcessor;
@@ -240,15 +239,15 @@ public class Leader {
try {
if (self.shouldUsePortUnification()) {
if (self.getQuorumListenOnAllIPs()) {
- ss = new UnifiedServerSocket(new QuorumX509Util(), self.getQuorumAddress().getPort());
+ ss = new UnifiedServerSocket(self.getX509Util(), true, self.getQuorumAddress().getPort());
} else {
- ss = new UnifiedServerSocket(new QuorumX509Util());
+ ss = new UnifiedServerSocket(self.getX509Util(), true);
}
} else if (self.isSslQuorum()) {
if (self.getQuorumListenOnAllIPs()) {
- ss = new QuorumX509Util().createSSLServerSocket(self.getQuorumAddress().getPort());
+ ss = self.getX509Util().createSSLServerSocket(self.getQuorumAddress().getPort());
} else {
- ss = new QuorumX509Util().createSSLServerSocket();
+ ss = self.getX509Util().createSSLServerSocket();
}
} else {
if (self.getQuorumListenOnAllIPs()) {
@@ -399,8 +398,10 @@ public class Leader {
public void run() {
try {
while (!stop) {
- try{
- Socket s = ss.accept();
+ Socket s = null;
+ boolean error = false;
+ try {
+ s = ss.accept();
// start with the initLimit, once the ack is processed
// in LearnerHandler switch to the syncLimit
@@ -412,6 +413,7 @@ public class Leader {
LearnerHandler fh = new LearnerHandler(s, is, Leader.this);
fh.start();
} catch (SocketException e) {
+ error = true;
if (stop) {
LOG.info("exception while shutting down acceptor: "
+ e);
@@ -425,6 +427,19 @@ public class Leader {
}
} catch (SaslException e){
LOG.error("Exception while connecting to quorum learner", e);
+ error = true;
+ } catch (Exception e) {
+ error = true;
+ throw e;
+ } finally {
+ // Don't leak sockets on errors
+ if (error && s != null && !s.isClosed()) {
+ try {
+ s.close();
+ } catch (IOException e) {
+ LOG.warn("Error closing socket", e);
+ }
+ }
}
}
} catch (Exception e) {
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java
index c740d53..faaa844 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java
@@ -38,9 +38,7 @@ import org.apache.jute.BinaryOutputArchive;
import org.apache.jute.InputArchive;
import org.apache.jute.OutputArchive;
import org.apache.jute.Record;
-import org.apache.zookeeper.common.QuorumX509Util;
import org.apache.zookeeper.common.X509Exception;
-import org.apache.zookeeper.common.X509Util;
import org.apache.zookeeper.server.ExitCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -74,8 +72,6 @@ public class Learner {
protected Socket sock;
- protected X509Util x509Util;
-
/**
* Socket getter
* @return
@@ -304,10 +300,7 @@ public class Learner {
private Socket createSocket() throws X509Exception, IOException {
Socket sock;
if (self.isSslQuorum()) {
- if (x509Util == null) {
- x509Util = new QuorumX509Util();
- }
- sock = x509Util.createSSLSocket();
+ sock = self.getX509Util().createSSLSocket();
} else {
sock = new Socket();
}
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java
index a86608f..94a526e 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java
@@ -18,16 +18,15 @@
package org.apache.zookeeper.server.quorum;
-import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
-import java.io.SequenceInputStream;
+import java.io.PushbackInputStream;
import java.net.Socket;
import java.net.SocketImpl;
public class PrependableSocket extends Socket {
- private SequenceInputStream sequenceInputStream;
+ private PushbackInputStream pushbackInputStream;
public PrependableSocket(SocketImpl base) throws IOException {
super(base);
@@ -35,15 +34,31 @@ public class PrependableSocket extends Socket {
@Override
public InputStream getInputStream() throws IOException {
- if (sequenceInputStream == null) {
+ if (pushbackInputStream == null) {
return super.getInputStream();
}
- return sequenceInputStream;
+ return pushbackInputStream;
}
- public void prependToInputStream(byte[] bytes) throws IOException {
- sequenceInputStream = new SequenceInputStream(new ByteArrayInputStream(bytes), getInputStream());
+ /**
+ * Prepend some bytes that have already been read back to the socket's input stream. Note that this method can be
+ * called at most once with a non-0 length per socket instance.
+ * @param bytes the bytes to prepend.
+ * @param offset offset in the byte array to start at.
+ * @param length number of bytes to prepend.
+ * @throws IOException if this method was already called on the socket instance, or if super.getInputStream() throws.
+ */
+ public void prependToInputStream(byte[] bytes, int offset, int length) throws IOException {
+ if (length == 0) {
+ return; // nothing to prepend
+ }
+ if (pushbackInputStream != null) {
+ throw new IOException("prependToInputStream() called more than once");
+ }
+ PushbackInputStream pushbackInputStream = new PushbackInputStream(getInputStream(), length);
+ pushbackInputStream.unread(bytes, offset, length);
+ this.pushbackInputStream = pushbackInputStream;
}
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java
index 8b91023..4175f3c 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java
@@ -47,9 +47,7 @@ import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
-import org.apache.zookeeper.common.QuorumX509Util;
import org.apache.zookeeper.common.X509Exception;
-import org.apache.zookeeper.common.X509Util;
import org.apache.zookeeper.server.ExitCode;
import org.apache.zookeeper.server.quorum.QuorumPeerConfig.ConfigException;
import org.apache.zookeeper.server.util.ConfigUtils;
@@ -175,9 +173,6 @@ public class QuorumCnxManager {
*/
private final boolean tcpKeepAlive = Boolean.getBoolean("zookeeper.tcpKeepAlive");
-
- private X509Util x509Util;
-
static public class Message {
Message(ByteBuffer buffer, long sid) {
this.buffer = buffer;
@@ -291,8 +286,6 @@ public class QuorumCnxManager {
// Starts listener thread that waits for connection requests
listener = new Listener();
listener.setName("QuorumPeerListener");
-
- x509Util = new QuorumX509Util();
}
private void initializeAuth(final long mySid,
@@ -655,17 +648,18 @@ public class QuorumCnxManager {
try {
LOG.debug("Opening channel to server " + sid);
if (self.isSslQuorum()) {
- SSLSocket sslSock = x509Util.createSSLSocket();
- setSockOpts(sslSock);
- sslSock.connect(electionAddr, cnxTO);
- sslSock.startHandshake();
- sock = sslSock;
- } else {
- sock = new Socket();
- setSockOpts(sock);
- sock.connect(electionAddr, cnxTO);
- }
- LOG.debug("Connected to server " + sid);
+ SSLSocket sslSock = self.getX509Util().createSSLSocket();
+ setSockOpts(sslSock);
+ sslSock.connect(electionAddr, cnxTO);
+ sslSock.startHandshake();
+ sock = sslSock;
+ } else {
+ sock = new Socket();
+ setSockOpts(sock);
+ sock.connect(electionAddr, cnxTO);
+
+ }
+ LOG.debug("Connected to server " + sid);
// Sends connection request asynchronously if the quorum
// sasl authentication is enabled. This is required because
// sasl server authentication process may take few seconds to
@@ -876,9 +870,9 @@ public class QuorumCnxManager {
while((!shutdown) && (numRetries < 3)){
try {
if (self.shouldUsePortUnification()) {
- ss = new UnifiedServerSocket(x509Util);
+ ss = new UnifiedServerSocket(self.getX509Util(), true);
} else if (self.isSslQuorum()) {
- ss = x509Util.createSSLServerSocket();
+ ss = self.getX509Util().createSSLServerSocket();
} else {
ss = new ServerSocket();
}
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java
index 136a538..7abde4b 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java
@@ -47,6 +47,7 @@ import javax.security.sasl.SaslException;
import org.apache.zookeeper.KeeperException.BadArgumentsException;
import org.apache.zookeeper.common.AtomicFileWritingIdiom;
import org.apache.zookeeper.common.AtomicFileWritingIdiom.WriterStatement;
+import org.apache.zookeeper.common.QuorumX509Util;
import org.apache.zookeeper.common.Time;
import org.apache.zookeeper.common.X509Exception;
import org.apache.zookeeper.jmx.MBeanRegistry;
@@ -479,6 +480,12 @@ public class QuorumPeer extends ZooKeeperThread implements QuorumStats.Provider
return shouldUsePortUnification;
}
+ private final QuorumX509Util x509Util;
+
+ QuorumX509Util getX509Util() {
+ return x509Util;
+ }
+
/**
* This is who I think the leader currently is.
*/
@@ -801,6 +808,7 @@ public class QuorumPeer extends ZooKeeperThread implements QuorumStats.Provider
quorumStats = new QuorumStats(this);
jmxRemotePeerBean = new HashMap<Long, RemotePeerBean>();
adminServer = AdminServerFactory.createAdminServer();
+ x509Util = new QuorumX509Util();
initialize();
}
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java
index 45463b1..aee5efc 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java
@@ -315,9 +315,8 @@ public class QuorumPeerConfig {
}
} else if (key.equals("sslQuorum")){
sslQuorum = Boolean.parseBoolean(value);
-// TODO: UnifiedServerSocket is currently buggy, will be fixed when @ivmaykov's PRs are merged. Disable port unification until then.
-// } else if (key.equals("portUnification")){
-// shouldUsePortUnification = Boolean.parseBoolean(value);
+ } else if (key.equals("portUnification")){
+ shouldUsePortUnification = Boolean.parseBoolean(value);
} else if ((key.startsWith("server.") || key.startsWith("group") || key.startsWith("weight")) && zkProp.containsKey("dynamicConfigFile")) {
throw new ConfigException("parameter: " + key + " must be in a separate dynamic config file");
} else if (key.equals(QuorumAuth.QUORUM_SASL_AUTH_ENABLED)) {
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java
index d1e3ba5..bbe245f 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java
@@ -27,23 +27,111 @@ import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLSocket;
import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
+import java.net.SocketAddress;
import java.net.SocketException;
+import java.net.SocketTimeoutException;
+import java.nio.channels.SocketChannel;
+/**
+ * A ServerSocket that can act either as a regular ServerSocket, as a SSLServerSocket, or as both, depending on
+ * the constructor parameters and on the type of client (TLS or plaintext) that connects to it.
+ * The constructors have the same signature as constructors of ServerSocket, with the addition of two parameters
+ * at the beginning:
+ * <ul>
+ * <li>X509Util - provides the SSL context to construct a secure socket when a client connects with TLS.</li>
+ * <li>boolean allowInsecureConnection - when true, acts as a hybrid server socket (plaintext / TLS). When
+ * false, acts as a SSLServerSocket (rejects plaintext connections).</li>
+ * </ul>
+ * The <code>!allowInsecureConnection</code> mode is needed so we can update the SSLContext (in particular, the
+ * key store and/or trust store) without having to re-create the server socket. By starting with a plaintext socket
+ * and delaying the upgrade to TLS until after a client has connected and begins a handshake, we can keep the same
+ * UnifiedServerSocket instance around, and replace the default SSLContext in the provided X509Util when the key store
+ * and/or trust store file changes on disk.
+ */
public class UnifiedServerSocket extends ServerSocket {
private static final Logger LOG = LoggerFactory.getLogger(UnifiedServerSocket.class);
private X509Util x509Util;
+ private final boolean allowInsecureConnection;
- public UnifiedServerSocket(X509Util x509Util) throws IOException {
+ /**
+ * Creates an unbound unified server socket by calling {@link ServerSocket#ServerSocket()}.
+ * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a
+ * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of
+ * the <code>allowInsecureConnection</code> parameter.
+ * @param x509Util the X509Util that provides the SSLContext to use for secure connections.
+ * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them.
+ * @throws IOException if {@link ServerSocket#ServerSocket()} throws.
+ */
+ public UnifiedServerSocket(X509Util x509Util, boolean allowInsecureConnection) throws IOException {
super();
this.x509Util = x509Util;
+ this.allowInsecureConnection = allowInsecureConnection;
}
- public UnifiedServerSocket(X509Util x509Util, int port) throws IOException {
+ /**
+ * Creates a unified server socket bound to the specified port by calling {@link ServerSocket#ServerSocket(int)}.
+ * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a
+ * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of
+ * the <code>allowInsecureConnection</code> parameter.
+ * @param x509Util the X509Util that provides the SSLContext to use for secure connections.
+ * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them.
+ * @param port the port number, or {@code 0} to use a port number that is automatically allocated.
+ * @throws IOException if {@link ServerSocket#ServerSocket(int)} throws.
+ */
+ public UnifiedServerSocket(X509Util x509Util, boolean allowInsecureConnection, int port) throws IOException {
super(port);
this.x509Util = x509Util;
+ this.allowInsecureConnection = allowInsecureConnection;
+ }
+
+ /**
+ * Creates a unified server socket bound to the specified port, with the specified backlog, by calling
+ * {@link ServerSocket#ServerSocket(int, int)}.
+ * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a
+ * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of
+ * the <code>allowInsecureConnection</code> parameter.
+ * @param x509Util the X509Util that provides the SSLContext to use for secure connections.
+ * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them.
+ * @param port the port number, or {@code 0} to use a port number that is automatically allocated.
+ * @param backlog requested maximum length of the queue of incoming connections.
+ * @throws IOException if {@link ServerSocket#ServerSocket(int, int)} throws.
+ */
+ public UnifiedServerSocket(X509Util x509Util,
+ boolean allowInsecureConnection,
+ int port,
+ int backlog) throws IOException {
+ super(port, backlog);
+ this.x509Util = x509Util;
+ this.allowInsecureConnection = allowInsecureConnection;
+ }
+
+ /**
+ * Creates a unified server socket bound to the specified port, with the specified backlog, and local IP address
+ * to bind to, by calling {@link ServerSocket#ServerSocket(int, int, InetAddress)}.
+ * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a
+ * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of
+ * the <code>allowInsecureConnection</code> parameter.
+ * @param x509Util the X509Util that provides the SSLContext to use for secure connections.
+ * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them.
+ * @param port the port number, or {@code 0} to use a port number that is automatically allocated.
+ * @param backlog requested maximum length of the queue of incoming connections.
+ * @param bindAddr the local InetAddress the server will bind to.
+ * @throws IOException if {@link ServerSocket#ServerSocket(int, int, InetAddress)} throws.
+ */
+ public UnifiedServerSocket(X509Util x509Util,
+ boolean allowInsecureConnection,
+ int port,
+ int backlog,
+ InetAddress bindAddr) throws IOException {
+ super(port, backlog, bindAddr);
+ this.x509Util = x509Util;
+ this.allowInsecureConnection = allowInsecureConnection;
}
@Override
@@ -56,24 +144,642 @@ public class UnifiedServerSocket extends ServerSocket {
}
final PrependableSocket prependableSocket = new PrependableSocket(null);
implAccept(prependableSocket);
+ return new UnifiedSocket(x509Util, allowInsecureConnection, prependableSocket);
+ }
+
+ /**
+ * The result of calling accept() on a UnifiedServerSocket. This is a Socket that doesn't know if it's
+ * using plaintext or SSL/TLS at the time when it is created. Calling a method that indicates a desire to
+ * read or write from the socket will cause the socket to detect if the connected client is attempting
+ * to establish a TLS or plaintext connection. This is done by doing a blocking read of 5 bytes off the
+ * socket and checking if the bytes look like the start of a TLS ClientHello message. If it looks like
+ * the client is attempting to connect with TLS, the internal socket is upgraded to a SSLSocket. If not,
+ * any bytes read from the socket are pushed back to the input stream, and the socket continues
+ * to be treated as a plaintext socket.
+ *
+ * The methods that trigger this behavior are:
+ * <ul>
+ * <li>{@link UnifiedSocket#getInputStream()}</li>
+ * <li>{@link UnifiedSocket#getOutputStream()}</li>
+ * <li>{@link UnifiedSocket#sendUrgentData(int)}</li>
+ * </ul>
+ *
+ * Calling other socket methods (i.e option setters such as {@link Socket#setTcpNoDelay(boolean)}) does
+ * not trigger mode detection.
+ *
+ * Because detecting the mode is a potentially blocking operation, it should not be done in the
+ * accepting thread. Attempting to read from or write to the socket in the accepting thread opens the
+ * caller up to a denial-of-service attack, in which a client connects and then does nothing. This would
+ * prevent any other clients from connecting. Passing the socket returned by accept() to a separate
+ * thread which handles all read and write operations protects against this DoS attack.
+ *
+ * Callers can check if the socket has been upgraded to TLS by calling {@link UnifiedSocket#isSecureSocket()},
+ * and can get the underlying SSLSocket by calling {@link UnifiedSocket#getSslSocket()}.
+ */
+ public static class UnifiedSocket extends Socket {
+ private enum Mode {
+ UNKNOWN,
+ PLAINTEXT,
+ TLS
+ }
- byte[] litmus = new byte[5];
- int bytesRead = prependableSocket.getInputStream().read(litmus, 0, 5);
- prependableSocket.prependToInputStream(litmus);
+ private final X509Util x509Util;
+ private final boolean allowInsecureConnection;
+ private PrependableSocket prependableSocket;
+ private SSLSocket sslSocket;
+ private Mode mode;
- if (bytesRead == 5 && SslHandler.isEncrypted(Unpooled.wrappedBuffer(litmus))) {
- LOG.info(getInetAddress() + " attempting to connect over ssl");
- SSLSocket sslSocket;
+ /**
+ * Note: this constructor is intentionally private. The only intended caller is
+ * {@link UnifiedServerSocket#accept()}.
+ *
+ * @param x509Util
+ * @param allowInsecureConnection
+ * @param prependableSocket
+ */
+ private UnifiedSocket(X509Util x509Util, boolean allowInsecureConnection, PrependableSocket prependableSocket) {
+ this.x509Util = x509Util;
+ this.allowInsecureConnection = allowInsecureConnection;
+ this.prependableSocket = prependableSocket;
+ this.sslSocket = null;
+ this.mode = Mode.UNKNOWN;
+ }
+
+ /**
+ * Returns true if the socket mode has been determined to be TLS.
+ * @return true if the mode is TLS, false if it is UNKNOWN or PLAINTEXT.
+ */
+ public boolean isSecureSocket() {
+ return mode == Mode.TLS;
+ }
+
+ /**
+ * Returns true if the socket mode has been determined to be PLAINTEXT.
+ * @return true if the mode is PLAINTEXT, false if it is UNKNOWN or TLS.
+ */
+ public boolean isPlaintextSocket() {
+ return mode == Mode.PLAINTEXT;
+ }
+
+ /**
+ * Returns true if the socket mode is not yet known.
+ * @return true if the mode is UNKNOWN, false if it is PLAINTEXT or TLS.
+ */
+ public boolean isModeKnown() {
+ return mode != Mode.UNKNOWN;
+ }
+
+ /**
+ * Detects the socket mode, see comments at the top of the class for more details. This operation will block
+ * for up to {@link X509Util#getSslHandshakeTimeoutMillis()} milliseconds and should not be called in the
+ * accept() thread if possible.
+ * @throws IOException
+ */
+ private void detectMode() throws IOException {
+ byte[] litmus = new byte[5];
+ int oldTimeout = -1;
+ int bytesRead = 0;
+ int newTimeout = x509Util.getSslHandshakeTimeoutMillis();
try {
- sslSocket = x509Util.createSSLSocket(prependableSocket);
- } catch (X509Exception e) {
- throw new IOException("failed to create SSL context", e);
+ oldTimeout = prependableSocket.getSoTimeout();
+ prependableSocket.setSoTimeout(newTimeout);
+ bytesRead = prependableSocket.getInputStream().read(litmus, 0, litmus.length);
+ } catch (SocketTimeoutException e) {
+ // Didn't read anything within the timeout, fallthrough and assume the connection is plaintext.
+ LOG.warn("Socket mode detection timed out after " + newTimeout + " ms, assuming PLAINTEXT");
+ } finally {
+ // restore socket timeout to the old value
+ try {
+ if (oldTimeout != -1) {
+ prependableSocket.setSoTimeout(oldTimeout);
+ }
+ } catch (Exception e) {
+ LOG.warn("Failed to restore old socket timeout value of " + oldTimeout + " ms", e);
+ }
+ }
+ if (bytesRead < 0) { // Got a EOF right away, definitely not using TLS. Fallthrough.
+ bytesRead = 0;
+ }
+
+ if (bytesRead == litmus.length && SslHandler.isEncrypted(Unpooled.wrappedBuffer(litmus))) {
+ try {
+ sslSocket = x509Util.createSSLSocket(prependableSocket, litmus);
+ } catch (X509Exception e) {
+ throw new IOException("failed to create SSL context", e);
+ }
+ prependableSocket = null;
+ mode = Mode.TLS;
+ } else if (allowInsecureConnection) {
+ prependableSocket.prependToInputStream(litmus, 0, bytesRead);
+ mode = Mode.PLAINTEXT;
+ } else {
+ prependableSocket.close();
+ mode = Mode.PLAINTEXT;
+ throw new IOException("Blocked insecure connection attempt");
+ }
+ }
+
+ private Socket getSocketAllowUnknownMode() {
+ if (isSecureSocket()) {
+ return sslSocket;
+ } else { // Note: mode is UNKNOWN or PLAINTEXT
+ return prependableSocket;
+ }
+ }
+
+ /**
+ * Returns the underlying socket, detecting the socket mode if it is not yet known. This is a potentially
+ * blocking operation and should not be called in the accept() thread.
+ * @return the underlying socket, after the socket mode has been determined.
+ * @throws IOException
+ */
+ private Socket getSocket() throws IOException {
+ if (!isModeKnown()) {
+ detectMode();
+ }
+ if (mode == Mode.TLS) {
+ return sslSocket;
+ } else {
+ return prependableSocket;
+ }
+ }
+
+ /**
+ * Returns the underlying SSLSocket if the mode is TLS. If the mode is UNKNOWN, causes mode detection which is a
+ * potentially blocking operation. If the mode ends up being PLAINTEXT, this will throw a SocketException, so
+ * callers are advised to only call this method after checking that {@link UnifiedSocket#isSecureSocket()}
+ * returned true.
+ * @return the underlying SSLSocket if the mode is known to be TLS.
+ * @throws IOException if detecting the socket mode fails
+ * @throws SocketException if the mode is PLAINTEXT.
+ */
+ public SSLSocket getSslSocket() throws IOException {
+ if (!isModeKnown()) {
+ detectMode();
+ }
+ if (!isSecureSocket()) {
+ throw new SocketException("Socket mode is not TLS");
}
- sslSocket.setUseClientMode(false);
return sslSocket;
- } else {
- LOG.info(getInetAddress() + " attempting to connect without ssl");
- return prependableSocket;
}
+
+ /**
+ * See {@link Socket#connect(SocketAddress)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void connect(SocketAddress endpoint) throws IOException {
+ getSocketAllowUnknownMode().connect(endpoint);
+ }
+
+ /**
+ * See {@link Socket#connect(SocketAddress, int)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void connect(SocketAddress endpoint, int timeout) throws IOException {
+ getSocketAllowUnknownMode().connect(endpoint, timeout);
+ }
+
+ /**
+ * See {@link Socket#bind(SocketAddress)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void bind(SocketAddress bindpoint) throws IOException {
+ getSocketAllowUnknownMode().bind(bindpoint);
+ }
+
+ /**
+ * See {@link Socket#getInetAddress()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public InetAddress getInetAddress() {
+ return getSocketAllowUnknownMode().getInetAddress();
+ }
+
+ /**
+ * See {@link Socket#getLocalAddress()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public InetAddress getLocalAddress() {
+ return getSocketAllowUnknownMode().getLocalAddress();
+ }
+
+ /**
+ * See {@link Socket#getPort()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public int getPort() {
+ return getSocketAllowUnknownMode().getPort();
+ }
+
+ /**
+ * See {@link Socket#getLocalPort()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public int getLocalPort() {
+ return getSocketAllowUnknownMode().getLocalPort();
+ }
+
+ /**
+ * See {@link Socket#getRemoteSocketAddress()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public SocketAddress getRemoteSocketAddress() {
+ return getSocketAllowUnknownMode().getRemoteSocketAddress();
+ }
+
+ /**
+ * See {@link Socket#getLocalSocketAddress()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public SocketAddress getLocalSocketAddress() {
+ return getSocketAllowUnknownMode().getLocalSocketAddress();
+ }
+
+ /**
+ * See {@link Socket#getChannel()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public SocketChannel getChannel() {
+ return getSocketAllowUnknownMode().getChannel();
+ }
+
+ /**
+ * See {@link Socket#getInputStream()}. If the socket mode has not yet been detected, the first read from the
+ * returned input stream will trigger mode detection, which is a potentially blocking operation. This means
+ * the accept() thread should avoid reading from this input stream if possible.
+ */
+ @Override
+ public InputStream getInputStream() throws IOException {
+ return new UnifiedInputStream(this);
+ }
+
+ /**
+ * See {@link Socket#getOutputStream()}. If the socket mode has not yet been detected, the first read from the
+ * returned input stream will trigger mode detection, which is a potentially blocking operation. This means
+ * the accept() thread should avoid reading from this input stream if possible.
+ */
+ @Override
+ public OutputStream getOutputStream() throws IOException {
+ return new UnifiedOutputStream(this);
+ }
+
+ /**
+ * See {@link Socket#setTcpNoDelay(boolean)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void setTcpNoDelay(boolean on) throws SocketException {
+ getSocketAllowUnknownMode().setTcpNoDelay(on);
+ }
+
+ /**
+ * See {@link Socket#getTcpNoDelay()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean getTcpNoDelay() throws SocketException {
+ return getSocketAllowUnknownMode().getTcpNoDelay();
+ }
+
+ /**
+ * See {@link Socket#setSoLinger(boolean, int)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void setSoLinger(boolean on, int linger) throws SocketException {
+ getSocketAllowUnknownMode().setSoLinger(on, linger);
+ }
+
+ /**
+ * See {@link Socket#getSoLinger()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public int getSoLinger() throws SocketException {
+ return getSocketAllowUnknownMode().getSoLinger();
+ }
+
+ /**
+ * See {@link Socket#sendUrgentData(int)}. Calling this method triggers mode detection, which is a potentially
+ * blocking operation, so it should not be done in the accept() thread.
+ */
+ @Override
+ public void sendUrgentData(int data) throws IOException {
+ getSocket().sendUrgentData(data);
+ }
+
+ /**
+ * See {@link Socket#setOOBInline(boolean)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void setOOBInline(boolean on) throws SocketException {
+ getSocketAllowUnknownMode().setOOBInline(on);
+ }
+
+ /**
+ * See {@link Socket#getOOBInline()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean getOOBInline() throws SocketException {
+ return getSocketAllowUnknownMode().getOOBInline();
+ }
+
+ /**
+ * See {@link Socket#setSoTimeout(int)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public synchronized void setSoTimeout(int timeout) throws SocketException {
+ getSocketAllowUnknownMode().setSoTimeout(timeout);
+ }
+
+ /**
+ * See {@link Socket#getSoTimeout()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public synchronized int getSoTimeout() throws SocketException {
+ return getSocketAllowUnknownMode().getSoTimeout();
+ }
+
+ /**
+ * See {@link Socket#setSendBufferSize(int)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public synchronized void setSendBufferSize(int size) throws SocketException {
+ getSocketAllowUnknownMode().setSendBufferSize(size);
+ }
+
+ /**
+ * See {@link Socket#getSendBufferSize()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public synchronized int getSendBufferSize() throws SocketException {
+ return getSocketAllowUnknownMode().getSendBufferSize();
+ }
+
+ /**
+ * See {@link Socket#setReceiveBufferSize(int)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public synchronized void setReceiveBufferSize(int size) throws SocketException {
+ getSocketAllowUnknownMode().setReceiveBufferSize(size);
+ }
+
+ /**
+ * See {@link Socket#getReceiveBufferSize()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public synchronized int getReceiveBufferSize() throws SocketException {
+ return getSocketAllowUnknownMode().getReceiveBufferSize();
+ }
+
+ /**
+ * See {@link Socket#setKeepAlive(boolean)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void setKeepAlive(boolean on) throws SocketException {
+ getSocketAllowUnknownMode().setKeepAlive(on);
+ }
+
+ /**
+ * See {@link Socket#getKeepAlive()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean getKeepAlive() throws SocketException {
+ return getSocketAllowUnknownMode().getKeepAlive();
+ }
+
+ /**
+ * See {@link Socket#setTrafficClass(int)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void setTrafficClass(int tc) throws SocketException {
+ getSocketAllowUnknownMode().setTrafficClass(tc);
+ }
+
+ /**
+ * See {@link Socket#getTrafficClass()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public int getTrafficClass() throws SocketException {
+ return getSocketAllowUnknownMode().getTrafficClass();
+ }
+
+ /**
+ * See {@link Socket#setReuseAddress(boolean)}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void setReuseAddress(boolean on) throws SocketException {
+ getSocketAllowUnknownMode().setReuseAddress(on);
+ }
+
+ /**
+ * See {@link Socket#getReuseAddress()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean getReuseAddress() throws SocketException {
+ return getSocketAllowUnknownMode().getReuseAddress();
+ }
+
+ /**
+ * See {@link Socket#close()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public synchronized void close() throws IOException {
+ getSocketAllowUnknownMode().close();
+ }
+
+ /**
+ * See {@link Socket#shutdownInput()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void shutdownInput() throws IOException {
+ getSocketAllowUnknownMode().shutdownInput();
+ }
+
+ /**
+ * See {@link Socket#shutdownOutput()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public void shutdownOutput() throws IOException {
+ getSocketAllowUnknownMode().shutdownOutput();
+ }
+
+ /**
+ * See {@link Socket#toString()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public String toString() {
+ return "UnifiedSocket[mode=" + mode.toString() + "socket=" + getSocketAllowUnknownMode().toString() + "]";
+ }
+
+ /**
+ * See {@link Socket#isConnected()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean isConnected() {
+ return getSocketAllowUnknownMode().isConnected();
+ }
+
+ /**
+ * See {@link Socket#isBound()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean isBound() {
+ return getSocketAllowUnknownMode().isBound();
+ }
+
+ /**
+ * See {@link Socket#isClosed()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean isClosed() {
+ return getSocketAllowUnknownMode().isClosed();
+ }
+
+ /**
+ * See {@link Socket#isInputShutdown()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean isInputShutdown() {
+ return getSocketAllowUnknownMode().isInputShutdown();
+ }
+
+ /**
+ * See {@link Socket#isOutputShutdown()}. Calling this method does not trigger mode detection.
+ */
+ @Override
+ public boolean isOutputShutdown() {
+ return getSocketAllowUnknownMode().isOutputShutdown();
+ }
+
+ /**
+ * See {@link Socket#setPerformancePreferences(int, int, int)}. Calling this method does not trigger
+ * mode detection.
+ */
+ @Override
+ public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) {
+ getSocketAllowUnknownMode().setPerformancePreferences(connectionTime, latency, bandwidth);
+ }
+ }
+
+ /**
+ * An input stream for a UnifiedSocket. The first read from this stream will trigger mode detection on the
+ * underlying UnifiedSocket.
+ */
+ private static class UnifiedInputStream extends InputStream {
+ private final UnifiedSocket unifiedSocket;
+ // Stream of the detected underlying socket; resolved lazily, null until first needed.
+ private InputStream realInputStream;
+
+ private UnifiedInputStream(UnifiedSocket unifiedSocket) {
+ this.unifiedSocket = unifiedSocket;
+ this.realInputStream = null;
+ }
+
+ @Override
+ public int read() throws IOException {
+ return getRealInputStream().read();
+ }
+
+ /**
+ * Note: SocketInputStream has optimized implementations of bulk-read operations, so we need to call them
+ * directly instead of relying on the base-class implementation which just calls the single-byte read() over
+ * and over. Not implementing these results in awful performance.
+ */
+ @Override
+ public int read(byte[] b) throws IOException {
+ return getRealInputStream().read(b);
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ return getRealInputStream().read(b, off, len);
+ }
+
+ private InputStream getRealInputStream() throws IOException {
+ if (realInputStream == null) {
+ // Note: The first call to getSocket() triggers mode detection which can block
+ realInputStream = unifiedSocket.getSocket().getInputStream();
+ }
+ return realInputStream;
+ }
+
+ @Override
+ public long skip(long n) throws IOException {
+ return getRealInputStream().skip(n);
+ }
+
+ @Override
+ public int available() throws IOException {
+ return getRealInputStream().available();
+ }
+
+ /**
+ * Closes this stream and, like any socket input stream, the socket itself. If the socket mode has not been
+ * detected yet, the socket is closed directly, rather than running the potentially blocking mode detection
+ * only to close the resulting stream.
+ */
+ @Override
+ public void close() throws IOException {
+ if (realInputStream == null && !unifiedSocket.isModeKnown()) {
+ unifiedSocket.close();
+ } else {
+ getRealInputStream().close();
+ }
+ }
+
+ @Override
+ public synchronized void mark(int readlimit) {
+ try {
+ getRealInputStream().mark(readlimit);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public synchronized void reset() throws IOException {
+ getRealInputStream().reset();
+ }
+
+ @Override
+ public boolean markSupported() {
+ try {
+ return getRealInputStream().markSupported();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ }
+
+ /**
+ * An output stream for a UnifiedSocket. The first write to or flush of this stream will trigger mode
+ * detection on the underlying UnifiedSocket.
+ */
+ private static class UnifiedOutputStream extends OutputStream {
+ private final UnifiedSocket unifiedSocket;
+ // Stream of the detected underlying socket; resolved lazily, null until first needed.
+ private OutputStream realOutputStream;
+
+ private UnifiedOutputStream(UnifiedSocket unifiedSocket) {
+ this.unifiedSocket = unifiedSocket;
+ this.realOutputStream = null;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ getRealOutputStream().write(b);
+ }
+
+ @Override
+ public void write(byte[] b) throws IOException {
+ getRealOutputStream().write(b);
+ }
+
+ @Override
+ public void write(byte[] b, int off, int len) throws IOException {
+ getRealOutputStream().write(b, off, len);
+ }
+
+ @Override
+ public void flush() throws IOException {
+ getRealOutputStream().flush();
+ }
+
+ /**
+ * Closes this stream and, like any socket output stream, the socket itself. If the socket mode has not been
+ * detected yet, the socket is closed directly, rather than running the potentially blocking mode detection
+ * only to close the resulting stream.
+ */
+ @Override
+ public void close() throws IOException {
+ if (realOutputStream == null && !unifiedSocket.isModeKnown()) {
+ unifiedSocket.close();
+ } else {
+ getRealOutputStream().close();
+ }
+ }
+
+ private OutputStream getRealOutputStream() throws IOException {
+ if (realOutputStream == null) {
+ // Note: The first call to getSocket() triggers mode detection which can block
+ realOutputStream = unifiedSocket.getSocket().getOutputStream();
+ }
+ return realOutputStream;
+ }
+
}
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
index 6b343c3..546cf55 100644
--- a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
@@ -356,6 +356,34 @@ public class X509UtilTest extends BaseX509ParameterizedTestCase {
true);
}
+ @Test
+ public void testGetSslHandshakeDetectionTimeoutMillisProperty() {
+ X509Util x509Util = new ClientX509Util();
+ Assert.assertEquals(
+ X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS,
+ x509Util.getSslHandshakeTimeoutMillis());
+ try {
+ String newPropertyString = Integer.toString(X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS + 1);
+ System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), newPropertyString);
+ // Note: need to create a new ClientX509Util to pick up modified property value
+ Assert.assertEquals(
+ X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS + 1,
+ new ClientX509Util().getSslHandshakeTimeoutMillis());
+ // 0 value not allowed, will return the default
+ System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "0");
+ Assert.assertEquals(
+ X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS,
+ new ClientX509Util().getSslHandshakeTimeoutMillis());
+ // Negative value not allowed, will return the default
+ System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "-1");
+ Assert.assertEquals(
+ X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS,
+ new ClientX509Util().getSslHandshakeTimeoutMillis());
+ } finally {
+ System.clearProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty());
+ }
+ }
+
// Warning: this will reset the x509Util
private void setCustomCipherSuites() {
System.setProperty(x509Util.getCipherSuitesProperty(), customCipherSuites[0] + "," + customCipherSuites[1]);
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java
index b088f47..67c15ad 100644
--- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java
@@ -80,7 +80,6 @@ import org.bouncycastle.util.io.pem.PemWriter;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
-import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
@@ -442,7 +441,6 @@ public class QuorumSSLTest extends QuorumPeerTestBase {
Assert.assertFalse(ClientBase.waitForServerUp("127.0.0.1:" + clientPortQp3, CONNECTION_TIMEOUT));
}
- @Ignore("portUnification is currently broken and disabled")
@Test
public void testRollingUpgrade() throws Exception {
// Form a quorum without ssl
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java
new file mode 100644
index 0000000..61862a4
--- /dev/null
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java
@@ -0,0 +1,404 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zookeeper.server.quorum;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.SocketOptions;
+import java.security.Security;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.zookeeper.PortAssignment;
+import org.apache.zookeeper.ZKTestCase;
+import org.apache.zookeeper.common.ClientX509Util;
+import org.apache.zookeeper.common.KeyStoreFileType;
+import org.apache.zookeeper.common.X509KeyType;
+import org.apache.zookeeper.common.X509TestContext;
+import org.apache.zookeeper.common.X509Util;
+import org.apache.zookeeper.test.ClientBase;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This test makes sure that certain operations on a UnifiedServerSocket do not
+ * trigger blocking mode detection. This is necessary to ensure that the
+ * Leader's accept() thread doesn't get blocked.
+ */
+@RunWith(Parameterized.class)
+public class UnifiedServerSocketModeDetectionTest extends ZKTestCase {
+ private static final Logger LOG = LoggerFactory.getLogger(UnifiedServerSocketModeDetectionTest.class);
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> params() {
+ // Each test runs once with a TLS client (true) and once with a plaintext client (false).
+ Collection<Object[]> parameters = new ArrayList<>();
+ for (boolean secureClient : new boolean[] { true, false }) {
+ parameters.add(new Object[] { secureClient });
+ }
+ return parameters;
+ }
+
+ private static File tempDir;
+ private static X509TestContext x509TestContext;
+
+ private boolean useSecureClient;
+ private X509Util x509Util;
+ private UnifiedServerSocket listeningSocket;
+ private UnifiedServerSocket.UnifiedSocket serverSideSocket;
+ private Socket clientSocket;
+ private ExecutorService workerPool;
+ private int port;
+ private InetSocketAddress localServerAddress;
+
+ @BeforeClass
+ public static void setUpClass() throws Exception {
+ // Registered once for the whole class; removed again in tearDownClass().
+ Security.addProvider(new BouncyCastleProvider());
+ tempDir = ClientBase.createEmptyTestDir();
+ // One shared X509 test context (EC keys for both key store and trust store) for all parameterized runs.
+ x509TestContext = X509TestContext.newBuilder()
+ .setTempDir(tempDir)
+ .setKeyStoreKeyType(X509KeyType.EC)
+ .setTrustStoreKeyType(X509KeyType.EC)
+ .build();
+ }
+
+ @AfterClass
+ public static void tearDownClass() {
+ try {
+ FileUtils.deleteDirectory(tempDir);
+ } catch (IOException ignored) {
+ // Best effort: failing to delete the temp dir must not fail the test run.
+ }
+ // Undo the provider registration done in setUpClass().
+ Security.removeProvider(BouncyCastleProvider.PROVIDER_NAME);
+ }
+
+ private static void forceClose(Socket s) {
+ // Null-safe, idempotent close used during cleanup.
+ if (s != null && !s.isClosed()) {
+ try {
+ s.close();
+ } catch (IOException ignored) {
+ // Best-effort cleanup; nothing useful can be done if close fails.
+ }
+ }
+ }
+
+ private static void forceClose(ServerSocket s) {
+ // Null-safe, idempotent close used during cleanup.
+ if (s != null && !s.isClosed()) {
+ try {
+ s.close();
+ } catch (IOException ignored) {
+ // Best-effort cleanup; nothing useful can be done if close fails.
+ }
+ }
+ }
+
+ public UnifiedServerSocketModeDetectionTest(Boolean useSecureClient) {
+ // Supplied by params(): true for a TLS client, false for a plaintext client.
+ this.useSecureClient = useSecureClient.booleanValue();
+ }
+
+ @Before
+ public void setUp() throws Exception {
+ x509Util = new ClientX509Util();
+ x509TestContext.setSystemProperties(x509Util, KeyStoreFileType.JKS, KeyStoreFileType.JKS);
+ // Shorten the handshake detection timeout to 100 ms for these tests.
+ System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "100");
+ workerPool = Executors.newCachedThreadPool();
+ port = PortAssignment.unique();
+ localServerAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), port);
+ listeningSocket = new UnifiedServerSocket(x509Util, true);
+ listeningSocket.bind(localServerAddress);
+
+ // accept() blocks until the client below connects, so it runs on the worker pool.
+ Future<UnifiedServerSocket.UnifiedSocket> acceptFuture = workerPool.submit(
+ new Callable<UnifiedServerSocket.UnifiedSocket>() {
+ @Override
+ public UnifiedServerSocket.UnifiedSocket call() throws Exception {
+ try {
+ return (UnifiedServerSocket.UnifiedSocket) listeningSocket.accept();
+ } catch (IOException e) {
+ LOG.error("Error in accept(): ", e);
+ throw e;
+ }
+ }
+ });
+
+ clientSocket = useSecureClient ? x509Util.createSSLSocket() : new Socket();
+ clientSocket.connect(localServerAddress);
+ if (!useSecureClient) {
+ // The plaintext client sends a few bytes up front so the server side has data available.
+ clientSocket.getOutputStream().write(new byte[] { 1, 2, 3, 4, 5 });
+ }
+ serverSideSocket = acceptFuture.get();
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ // Undo the system properties set in setUp().
+ x509TestContext.clearSystemProperties(x509Util);
+ System.clearProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty());
+ // Release the server-side sockets first, then the client.
+ forceClose(listeningSocket);
+ forceClose(serverSideSocket);
+ forceClose(clientSocket);
+ workerPool.shutdown();
+ workerPool.awaitTermination(1, TimeUnit.SECONDS);
+ }
+
+ @Test
+ public void testGetInetAddress() {
+ serverSideSocket.getInetAddress();
+ Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ @Test
+ public void testGetLocalAddress() {
+ serverSideSocket.getLocalAddress();
+ Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ @Test
+ public void testGetPort() {
+ serverSideSocket.getPort();
+ Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ @Test
+ public void testGetLocalPort() {
+ serverSideSocket.getLocalPort();
+ Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ @Test
+ public void testGetRemoteSocketAddress() {
+ serverSideSocket.getRemoteSocketAddress();
+ Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ @Test
+ public void testGetLocalSocketAddress() {
+ serverSideSocket.getLocalSocketAddress();
+ Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ @Test
+ public void testGetInputStream() throws IOException {
+ serverSideSocket.getInputStream();
+ Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ @Test
+ public void testGetOutputStream() throws IOException {
+ serverSideSocket.getOutputStream();
+ Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ /** Reading TCP_NODELAY must leave the socket's mode undetermined. */
+ @Test
+ public void testGetTcpNoDelay() throws IOException {
+     serverSideSocket.getTcpNoDelay();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Toggling TCP_NODELAY must not trigger mode detection, and the new value
+  * must be visible through the getter.
+  */
+ @Test
+ public void testSetTcpNoDelay() throws IOException {
+     final boolean flipped = !serverSideSocket.getTcpNoDelay();
+     serverSideSocket.setTcpNoDelay(flipped);
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+     Assert.assertEquals(flipped, serverSideSocket.getTcpNoDelay());
+ }
+
+ /** Reading SO_LINGER must leave the socket's mode undetermined. */
+ @Test
+ public void testGetSoLinger() throws IOException {
+     serverSideSocket.getSoLinger();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Flips SO_LINGER: it is enabled with a 1-second linger if it was disabled (-1),
+  * and disabled otherwise. Either way mode detection must not be triggered, and
+  * the getter must report the new setting.
+  */
+ @Test
+ public void testSetSoLinger() throws IOException {
+     final boolean lingerWasDisabled = serverSideSocket.getSoLinger() == -1;
+     final int expectedLinger = lingerWasDisabled ? 1 : -1;
+     serverSideSocket.setSoLinger(lingerWasDisabled, expectedLinger);
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+     Assert.assertEquals(expectedLinger, serverSideSocket.getSoLinger());
+ }
+
+ /** Reading SO_TIMEOUT must leave the socket's mode undetermined. */
+ @Test
+ public void testGetSoTimeout() throws IOException {
+     serverSideSocket.getSoTimeout();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Raising SO_TIMEOUT by 10 ms must not trigger mode detection, and the new
+  * timeout must be visible through the getter.
+  */
+ @Test
+ public void testSetSoTimeout() throws IOException {
+     final int newTimeout = serverSideSocket.getSoTimeout() + 10;
+     serverSideSocket.setSoTimeout(newTimeout);
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+     Assert.assertEquals(newTimeout, serverSideSocket.getSoTimeout());
+ }
+
+ /** Reading SO_SNDBUF must leave the socket's mode undetermined. */
+ @Test
+ public void testGetSendBufferSize() throws IOException {
+     serverSideSocket.getSendBufferSize();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Growing SO_SNDBUF by 1 KiB must not trigger mode detection.
+  * The requested size is only a hint that the socket implementation may ignore,
+  * so the value is deliberately not read back.
+  */
+ @Test
+ public void testSetSendBufferSize() throws IOException {
+     final int requestedSize = serverSideSocket.getSendBufferSize() + 1024;
+     serverSideSocket.setSendBufferSize(requestedSize);
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ /** Reading SO_RCVBUF must leave the socket's mode undetermined. */
+ @Test
+ public void testGetReceiveBufferSize() throws IOException {
+     serverSideSocket.getReceiveBufferSize();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Growing SO_RCVBUF by 1 KiB must not trigger mode detection.
+  * The requested size is only a hint that the socket implementation may ignore,
+  * so the value is deliberately not read back.
+  */
+ @Test
+ public void testSetReceiveBufferSize() throws IOException {
+     final int requestedSize = serverSideSocket.getReceiveBufferSize() + 1024;
+     serverSideSocket.setReceiveBufferSize(requestedSize);
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ /** Reading SO_KEEPALIVE must leave the socket's mode undetermined. */
+ @Test
+ public void testGetKeepAlive() throws IOException {
+     serverSideSocket.getKeepAlive();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Toggling SO_KEEPALIVE must not trigger mode detection, and the new value
+  * must be visible through the getter.
+  */
+ @Test
+ public void testSetKeepAlive() throws IOException {
+     final boolean flipped = !serverSideSocket.getKeepAlive();
+     serverSideSocket.setKeepAlive(flipped);
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+     Assert.assertEquals(flipped, serverSideSocket.getKeepAlive());
+ }
+
+ /** Reading the IP traffic class must leave the socket's mode undetermined. */
+ @Test
+ public void testGetTrafficClass() throws IOException {
+     serverSideSocket.getTrafficClass();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Setting the IP traffic class must not trigger TLS/plaintext mode detection.
+  * <p>
+  * A genuine TOS octet is used here: IPTOS_LOWDELAY (0x10), the example value from
+  * the {@link java.net.Socket#setTrafficClass(int)} javadoc. Note that
+  * {@link SocketOptions#IP_TOS} is the identifier of the socket option itself, not a
+  * traffic-class value, so it must not be passed as the argument.
+  * <p>
+  * Socket implementations may ignore setTrafficClass(), so the value is not read back.
+  */
+ @Test
+ public void testSetTrafficClass() throws IOException {
+     final int iptosLowDelay = 0x10;
+     serverSideSocket.setTrafficClass(iptosLowDelay);
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+ }
+
+ /** Reading SO_REUSEADDR must leave the socket's mode undetermined. */
+ @Test
+ public void testGetReuseAddress() throws IOException {
+     serverSideSocket.getReuseAddress();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Toggling SO_REUSEADDR must not trigger mode detection, and the new value
+  * must be visible through the getter.
+  */
+ @Test
+ public void testSetReuseAddress() throws IOException {
+     final boolean flipped = !serverSideSocket.getReuseAddress();
+     serverSideSocket.setReuseAddress(flipped);
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+     Assert.assertEquals(flipped, serverSideSocket.getReuseAddress());
+ }
+
+ /** Closing a socket whose mode was never detected must not trigger detection. */
+ @Test
+ public void testClose() throws IOException {
+     serverSideSocket.close();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /** Shutting down the input side must not trigger mode detection. */
+ @Test
+ public void testShutdownInput() throws IOException {
+     serverSideSocket.shutdownInput();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /** Shutting down the output side must not trigger mode detection. */
+ @Test
+ public void testShutdownOutput() throws IOException {
+     serverSideSocket.shutdownOutput();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /** Querying isConnected() must leave the socket's mode undetermined. */
+ @Test
+ public void testIsConnected() {
+     serverSideSocket.isConnected();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /** Querying isBound() must leave the socket's mode undetermined. */
+ @Test
+ public void testIsBound() {
+     serverSideSocket.isBound();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /** Querying isClosed() must leave the socket's mode undetermined. */
+ @Test
+ public void testIsClosed() {
+     serverSideSocket.isClosed();
+     final boolean modeKnown = serverSideSocket.isModeKnown();
+     Assert.assertFalse(modeKnown);
+ }
+
+ /**
+  * Querying isInputShutdown() must not trigger mode detection.
+  * The flag must start out false on a freshly accepted socket and become true
+  * after shutdownInput().
+  */
+ @Test
+ public void testIsInputShutdown() throws IOException {
+     // Previously the initial result was discarded; pin the false -> true transition.
+     Assert.assertFalse(serverSideSocket.isInputShutdown());
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+     serverSideSocket.shutdownInput();
+     Assert.assertTrue(serverSideSocket.isInputShutdown());
+ }
+
+ /**
+  * Querying isOutputShutdown() must not trigger mode detection.
+  * The flag must start out false on a freshly accepted socket and become true
+  * after shutdownOutput().
+  */
+ @Test
+ public void testIsOutputShutdown() throws IOException {
+     // Previously the initial result was discarded; pin the false -> true transition.
+     Assert.assertFalse(serverSideSocket.isOutputShutdown());
+     Assert.assertFalse(serverSideSocket.isModeKnown());
+     serverSideSocket.shutdownOutput();
+     Assert.assertTrue(serverSideSocket.isOutputShutdown());
+ }
+}