You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@zookeeper.apache.org by an...@apache.org on 2018/11/27 09:02:32 UTC

[2/2] zookeeper git commit: ZOOKEEPER-3172: Quorum TLS - fix port unification to allow rolling upgrades

ZOOKEEPER-3172: Quorum TLS - fix port unification to allow rolling upgrades

Fix numerous problems with UnifiedServerSocket, such as hanging the accept() thread when the client doesn't send any data or crashing if fewer than 5 bytes are read from the socket in the initial read.

Re-enable the "portUnification" config option.

## Fixed networking issues/bugs in UnifiedServerSocket

- don't crash the `accept()` thread if the client closes the connection without sending any data
- don't corrupt the connection if the client sends fewer than 5 bytes for the initial read
- delay the detection of TLS vs. plaintext mode until a socket stream is read from or written to. This prevents the `accept()` thread from getting blocked on a `read()` operation from the newly connected socket.
- fix `PrependableSocket` reads that span the prepended data: previously, prepending 5 bytes and then trying to read more than 5 bytes would return only the 5 prepended bytes, even if more bytes were available from the socket.

Author: Ilya Maykov <il...@fb.com>

Reviewers: andor@apache.org

Closes #679 from ivmaykov/ZOOKEEPER-3172


Project: http://git-wip-us.apache.org/repos/asf/zookeeper/repo
Commit: http://git-wip-us.apache.org/repos/asf/zookeeper/commit/64104eae
Tree: http://git-wip-us.apache.org/repos/asf/zookeeper/tree/64104eae
Diff: http://git-wip-us.apache.org/repos/asf/zookeeper/diff/64104eae

Branch: refs/heads/master
Commit: 64104eaeaa6508f052edfd39c24243a8e26039dc
Parents: 91c6cb2
Author: Ilya Maykov <il...@fb.com>
Authored: Tue Nov 27 10:02:24 2018 +0100
Committer: Andor Molnar <an...@apache.org>
Committed: Tue Nov 27 10:02:24 2018 +0100

----------------------------------------------------------------------
 .../org/apache/zookeeper/common/X509Util.java   |  55 +-
 .../org/apache/zookeeper/common/ZKConfig.java   |   2 +
 .../apache/zookeeper/server/quorum/Leader.java  |  29 +-
 .../apache/zookeeper/server/quorum/Learner.java |   9 +-
 .../server/quorum/PrependableSocket.java        |  29 +-
 .../server/quorum/QuorumCnxManager.java         |  34 +-
 .../zookeeper/server/quorum/QuorumPeer.java     |   8 +
 .../server/quorum/QuorumPeerConfig.java         |   5 +-
 .../server/quorum/UnifiedServerSocket.java      | 738 ++++++++++++++++++-
 .../apache/zookeeper/common/X509UtilTest.java   |  28 +
 .../zookeeper/server/quorum/QuorumSSLTest.java  |   2 -
 .../UnifiedServerSocketModeDetectionTest.java   | 404 ++++++++++
 .../server/quorum/UnifiedServerSocketTest.java  | 608 ++++++++++++---
 13 files changed, 1794 insertions(+), 157 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
index 5b97ac6..e3625a5 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
@@ -18,6 +18,7 @@
 package org.apache.zookeeper.common;
 
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.net.Socket;
 import java.security.GeneralSecurityException;
@@ -74,6 +75,8 @@ public abstract class X509Util {
             "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256"
     };
 
+    public static final int DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS = 5000;
+
     private String sslProtocolProperty = getConfigPrefix() + "protocol";
     private String cipherSuitesProperty = getConfigPrefix() + "ciphersuites";
     private String sslKeystoreLocationProperty = getConfigPrefix() + "keyStore.location";
@@ -85,6 +88,7 @@ public abstract class X509Util {
     private String sslHostnameVerificationEnabledProperty = getConfigPrefix() + "hostnameVerification";
     private String sslCrlEnabledProperty = getConfigPrefix() + "crl";
     private String sslOcspEnabledProperty = getConfigPrefix() + "ocsp";
+    private String sslHandshakeDetectionTimeoutMillisProperty = getConfigPrefix() + "handshakeDetectionTimeoutMillis";
 
     private String[] cipherSuites;
 
@@ -146,6 +150,16 @@ public abstract class X509Util {
         return sslOcspEnabledProperty;
     }
 
+    /**
+     * Returns the config property key that controls the amount of time, in milliseconds, that the first
+     * UnifiedServerSocket read operation will block for when trying to detect the client mode (TLS or PLAINTEXT).
+     *
+     * @return the config property key.
+     */
+    public String getSslHandshakeDetectionTimeoutMillisProperty() {
+        return sslHandshakeDetectionTimeoutMillisProperty;
+    }
+
     public SSLContext getDefaultSSLContext() throws X509Exception.SSLContextException {
         SSLContext result = defaultSSLContext.get();
         if (result == null) {
@@ -168,6 +182,31 @@ public abstract class X509Util {
         return createSSLContext(config);
     }
 
+    /**
+     * Returns the max amount of time, in milliseconds, that the first UnifiedServerSocket read() operation should
+     * block for when trying to detect the client mode (TLS or PLAINTEXT).
+     * Defaults to {@link X509Util#DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS}.
+     *
+     * @return the handshake detection timeout, in milliseconds.
+     */
+    public int getSslHandshakeTimeoutMillis() {
+        String propertyString = System.getProperty(getSslHandshakeDetectionTimeoutMillisProperty());
+        int result;
+        if (propertyString == null) {
+            result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+        } else {
+            result = Integer.parseInt(propertyString);
+            if (result < 1) {
+                // Timeout of 0 is not allowed, since an infinite timeout can permanently lock up an
+                // accept() thread.
+                LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result +
+                        ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS);
+                result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+            }
+        }
+        return result;
+    }
+
     public SSLContext createSSLContext(ZKConfig config) throws SSLContextException {
         KeyManager[] keyManagers = null;
         TrustManager[] trustManagers = null;
@@ -350,14 +389,22 @@ public abstract class X509Util {
     public SSLSocket createSSLSocket() throws X509Exception, IOException {
         SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket();
         configureSSLSocket(sslSocket);
-
+        sslSocket.setUseClientMode(true);
         return sslSocket;
     }
 
-    public SSLSocket createSSLSocket(Socket socket) throws X509Exception, IOException {
-        SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(socket, null, socket.getPort(), true);
+    public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws X509Exception, IOException {
+        SSLSocket sslSocket;
+        if (pushbackBytes != null && pushbackBytes.length > 0) {
+            sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(
+                    socket, new ByteArrayInputStream(pushbackBytes), true);
+        } else {
+            sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(
+                    socket, null, socket.getPort(), true);
+        }
         configureSSLSocket(sslSocket);
-
+        sslSocket.setUseClientMode(false);
+        sslSocket.setNeedClientAuth(true);
         return sslSocket;
     }
 

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
index 01bac69..effc0d5 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
@@ -130,6 +130,8 @@ public class ZKConfig {
                 System.getProperty(x509Util.getSslCrlEnabledProperty()));
         properties.put(x509Util.getSslOcspEnabledProperty(),
                 System.getProperty(x509Util.getSslOcspEnabledProperty()));
+        properties.put(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(),
+                System.getProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty()));
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java
index 9270548..0a892b1 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java
@@ -42,7 +42,6 @@ import java.util.concurrent.ConcurrentMap;
 import javax.security.sasl.SaslException;
 
 import org.apache.zookeeper.ZooDefs.OpCode;
-import org.apache.zookeeper.common.QuorumX509Util;
 import org.apache.zookeeper.common.Time;
 import org.apache.zookeeper.common.X509Exception;
 import org.apache.zookeeper.server.FinalRequestProcessor;
@@ -240,15 +239,15 @@ public class Leader {
         try {
             if (self.shouldUsePortUnification()) {
                 if (self.getQuorumListenOnAllIPs()) {
-                    ss = new UnifiedServerSocket(new QuorumX509Util(), self.getQuorumAddress().getPort());
+                    ss = new UnifiedServerSocket(self.getX509Util(), true, self.getQuorumAddress().getPort());
                 } else {
-                    ss = new UnifiedServerSocket(new QuorumX509Util());
+                    ss = new UnifiedServerSocket(self.getX509Util(), true);
                 }
             } else if (self.isSslQuorum()) {
                 if (self.getQuorumListenOnAllIPs()) {
-                    ss = new QuorumX509Util().createSSLServerSocket(self.getQuorumAddress().getPort());
+                    ss = self.getX509Util().createSSLServerSocket(self.getQuorumAddress().getPort());
                 } else {
-                    ss = new QuorumX509Util().createSSLServerSocket();
+                    ss = self.getX509Util().createSSLServerSocket();
                 }
             } else {
                 if (self.getQuorumListenOnAllIPs()) {
@@ -399,8 +398,10 @@ public class Leader {
         public void run() {
             try {
                 while (!stop) {
-                    try{
-                        Socket s = ss.accept();
+                    Socket s = null;
+                    boolean error = false;
+                    try {
+                        s = ss.accept();
 
                         // start with the initLimit, once the ack is processed
                         // in LearnerHandler switch to the syncLimit
@@ -412,6 +413,7 @@ public class Leader {
                         LearnerHandler fh = new LearnerHandler(s, is, Leader.this);
                         fh.start();
                     } catch (SocketException e) {
+                        error = true;
                         if (stop) {
                             LOG.info("exception while shutting down acceptor: "
                                     + e);
@@ -425,6 +427,19 @@ public class Leader {
                         }
                     } catch (SaslException e){
                         LOG.error("Exception while connecting to quorum learner", e);
+                        error = true;
+                    } catch (Exception e) {
+                        error = true;
+                        throw e;
+                    } finally {
+                        // Don't leak sockets on errors
+                        if (error && s != null && !s.isClosed()) {
+                            try {
+                                s.close();
+                            } catch (IOException e) {
+                                LOG.warn("Error closing socket", e);
+                            }
+                        }
                     }
                 }
             } catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java
index c740d53..faaa844 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java
@@ -38,9 +38,7 @@ import org.apache.jute.BinaryOutputArchive;
 import org.apache.jute.InputArchive;
 import org.apache.jute.OutputArchive;
 import org.apache.jute.Record;
-import org.apache.zookeeper.common.QuorumX509Util;
 import org.apache.zookeeper.common.X509Exception;
-import org.apache.zookeeper.common.X509Util;
 import org.apache.zookeeper.server.ExitCode;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -74,8 +72,6 @@ public class Learner {
     
     protected Socket sock;
 
-    protected X509Util x509Util;
-    
     /**
      * Socket getter
      * @return 
@@ -304,10 +300,7 @@ public class Learner {
     private Socket createSocket() throws X509Exception, IOException {
         Socket sock;
         if (self.isSslQuorum()) {
-            if (x509Util == null) {
-                x509Util = new QuorumX509Util();
-            }
-            sock = x509Util.createSSLSocket();
+            sock = self.getX509Util().createSSLSocket();
         } else {
             sock = new Socket();
         }

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java
index a86608f..94a526e 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java
@@ -18,16 +18,15 @@
 
 package org.apache.zookeeper.server.quorum;
 
-import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
-import java.io.SequenceInputStream;
+import java.io.PushbackInputStream;
 import java.net.Socket;
 import java.net.SocketImpl;
 
 public class PrependableSocket extends Socket {
 
-  private SequenceInputStream sequenceInputStream;
+  private PushbackInputStream pushbackInputStream;
 
   public PrependableSocket(SocketImpl base) throws IOException {
     super(base);
@@ -35,15 +34,31 @@ public class PrependableSocket extends Socket {
 
   @Override
   public InputStream getInputStream() throws IOException {
-    if (sequenceInputStream == null) {
+    if (pushbackInputStream == null) {
       return super.getInputStream();
     }
 
-    return sequenceInputStream;
+    return pushbackInputStream;
   }
 
-  public void prependToInputStream(byte[] bytes) throws IOException {
-    sequenceInputStream = new SequenceInputStream(new ByteArrayInputStream(bytes), getInputStream());
+  /**
+   * Prepend some bytes that have already been read back to the socket's input stream. Note that this method can be
+   * called at most once with a non-0 length per socket instance.
+   * @param bytes the bytes to prepend.
+   * @param offset offset in the byte array to start at.
+   * @param length number of bytes to prepend.
+   * @throws IOException if this method was already called on the socket instance, or if super.getInputStream() throws.
+   */
+  public void prependToInputStream(byte[] bytes, int offset, int length) throws IOException {
+    if (length == 0) {
+      return; // nothing to prepend
+    }
+    if (pushbackInputStream != null) {
+      throw new IOException("prependToInputStream() called more than once");
+    }
+    PushbackInputStream pushbackInputStream = new PushbackInputStream(getInputStream(), length);
+    pushbackInputStream.unread(bytes, offset, length);
+    this.pushbackInputStream = pushbackInputStream;
   }
 
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java
index 8b91023..4175f3c 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java
@@ -47,9 +47,7 @@ import java.util.NoSuchElementException;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 
-import org.apache.zookeeper.common.QuorumX509Util;
 import org.apache.zookeeper.common.X509Exception;
-import org.apache.zookeeper.common.X509Util;
 import org.apache.zookeeper.server.ExitCode;
 import org.apache.zookeeper.server.quorum.QuorumPeerConfig.ConfigException;
 import org.apache.zookeeper.server.util.ConfigUtils;
@@ -175,9 +173,6 @@ public class QuorumCnxManager {
      */
     private final boolean tcpKeepAlive = Boolean.getBoolean("zookeeper.tcpKeepAlive");
 
-
-    private X509Util x509Util;
-
     static public class Message {
         Message(ByteBuffer buffer, long sid) {
             this.buffer = buffer;
@@ -291,8 +286,6 @@ public class QuorumCnxManager {
         // Starts listener thread that waits for connection requests
         listener = new Listener();
         listener.setName("QuorumPeerListener");
-
-        x509Util = new QuorumX509Util();
     }
 
     private void initializeAuth(final long mySid,
@@ -655,17 +648,18 @@ public class QuorumCnxManager {
         try {
             LOG.debug("Opening channel to server " + sid);
             if (self.isSslQuorum()) {
-                SSLSocket sslSock = x509Util.createSSLSocket();
-                setSockOpts(sslSock);
-                sslSock.connect(electionAddr, cnxTO);
-                sslSock.startHandshake();
-                sock = sslSock;
-            } else {
-                sock = new Socket();
-                setSockOpts(sock);
-                sock.connect(electionAddr, cnxTO);
-            }
-            LOG.debug("Connected to server " + sid);
+                 SSLSocket sslSock = self.getX509Util().createSSLSocket();
+                 setSockOpts(sslSock);
+                 sslSock.connect(electionAddr, cnxTO);
+                 sslSock.startHandshake();
+                 sock = sslSock;
+             } else {
+                 sock = new Socket();
+                 setSockOpts(sock);
+                 sock.connect(electionAddr, cnxTO);
+
+             }
+             LOG.debug("Connected to server " + sid);
             // Sends connection request asynchronously if the quorum
             // sasl authentication is enabled. This is required because
             // sasl server authentication process may take few seconds to
@@ -876,9 +870,9 @@ public class QuorumCnxManager {
             while((!shutdown) && (numRetries < 3)){
                 try {
                     if (self.shouldUsePortUnification()) {
-                        ss = new UnifiedServerSocket(x509Util);
+                        ss = new UnifiedServerSocket(self.getX509Util(), true);
                     } else if (self.isSslQuorum()) {
-                        ss = x509Util.createSSLServerSocket();
+                        ss = self.getX509Util().createSSLServerSocket();
                     } else {
                         ss = new ServerSocket();
                     }

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java
index 136a538..7abde4b 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java
@@ -47,6 +47,7 @@ import javax.security.sasl.SaslException;
 import org.apache.zookeeper.KeeperException.BadArgumentsException;
 import org.apache.zookeeper.common.AtomicFileWritingIdiom;
 import org.apache.zookeeper.common.AtomicFileWritingIdiom.WriterStatement;
+import org.apache.zookeeper.common.QuorumX509Util;
 import org.apache.zookeeper.common.Time;
 import org.apache.zookeeper.common.X509Exception;
 import org.apache.zookeeper.jmx.MBeanRegistry;
@@ -479,6 +480,12 @@ public class QuorumPeer extends ZooKeeperThread implements QuorumStats.Provider
         return shouldUsePortUnification;
     }
 
+    private final QuorumX509Util x509Util;
+
+    QuorumX509Util getX509Util() {
+        return x509Util;
+    }
+
     /**
      * This is who I think the leader currently is.
      */
@@ -801,6 +808,7 @@ public class QuorumPeer extends ZooKeeperThread implements QuorumStats.Provider
         quorumStats = new QuorumStats(this);
         jmxRemotePeerBean = new HashMap<Long, RemotePeerBean>();
         adminServer = AdminServerFactory.createAdminServer();
+        x509Util = new QuorumX509Util();
         initialize();
     }
 

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java
index 45463b1..aee5efc 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java
@@ -315,9 +315,8 @@ public class QuorumPeerConfig {
                 }
             } else if (key.equals("sslQuorum")){
                 sslQuorum = Boolean.parseBoolean(value);
-// TODO: UnifiedServerSocket is currently buggy, will be fixed when @ivmaykov's PRs are merged. Disable port unification until then.
-//            } else if (key.equals("portUnification")){
-//                shouldUsePortUnification = Boolean.parseBoolean(value);
+            } else if (key.equals("portUnification")){
+                shouldUsePortUnification = Boolean.parseBoolean(value);
             } else if ((key.startsWith("server.") || key.startsWith("group") || key.startsWith("weight")) && zkProp.containsKey("dynamicConfigFile")) {
                 throw new ConfigException("parameter: " + key + " must be in a separate dynamic config file");
             } else if (key.equals(QuorumAuth.QUORUM_SASL_AUTH_ENABLED)) {

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java
index d1e3ba5..bbe245f 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java
@@ -27,23 +27,111 @@ import org.slf4j.LoggerFactory;
 
 import javax.net.ssl.SSLSocket;
 import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.InetAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
+import java.net.SocketAddress;
 import java.net.SocketException;
+import java.net.SocketTimeoutException;
+import java.nio.channels.SocketChannel;
 
+/**
+ * A ServerSocket that can act either as a regular ServerSocket, as a SSLServerSocket, or as both, depending on
+ * the constructor parameters and on the type of client (TLS or plaintext) that connects to it.
+ * The constructors have the same signature as constructors of ServerSocket, with the addition of two parameters
+ * at the beginning:
+ * <ul>
+ *     <li>X509Util - provides the SSL context to construct a secure socket when a client connects with TLS.</li>
+ *     <li>boolean allowInsecureConnection - when true, acts as a hybrid server socket (plaintext / TLS). When
+ *         false, acts as a SSLServerSocket (rejects plaintext connections).</li>
+ * </ul>
+ * The <code>!allowInsecureConnection</code> mode is needed so we can update the SSLContext (in particular, the
+ * key store and/or trust store) without having to re-create the server socket. By starting with a plaintext socket
+ * and delaying the upgrade to TLS until after a client has connected and begins a handshake, we can keep the same
+ * UnifiedServerSocket instance around, and replace the default SSLContext in the provided X509Util when the key store
+ * and/or trust store file changes on disk.
+ */
 public class UnifiedServerSocket extends ServerSocket {
     private static final Logger LOG = LoggerFactory.getLogger(UnifiedServerSocket.class);
 
     private X509Util x509Util;
+    private final boolean allowInsecureConnection;
 
-    public UnifiedServerSocket(X509Util x509Util) throws IOException {
+    /**
+     * Creates an unbound unified server socket by calling {@link ServerSocket#ServerSocket()}.
+     * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a
+     * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of
+     * the <code>allowInsecureConnection</code> parameter.
+     * @param x509Util the X509Util that provides the SSLContext to use for secure connections.
+     * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them.
+     * @throws IOException if {@link ServerSocket#ServerSocket()} throws.
+     */
+    public UnifiedServerSocket(X509Util x509Util, boolean allowInsecureConnection) throws IOException {
         super();
         this.x509Util = x509Util;
+        this.allowInsecureConnection = allowInsecureConnection;
     }
 
-    public UnifiedServerSocket(X509Util x509Util, int port) throws IOException {
+    /**
+     * Creates a unified server socket bound to the specified port by calling {@link ServerSocket#ServerSocket(int)}.
+     * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a
+     * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of
+     * the <code>allowInsecureConnection</code> parameter.
+     * @param x509Util the X509Util that provides the SSLContext to use for secure connections.
+     * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them.
+     * @param port the port number, or {@code 0} to use a port number that is automatically allocated.
+     * @throws IOException if {@link ServerSocket#ServerSocket(int)} throws.
+     */
+    public UnifiedServerSocket(X509Util x509Util, boolean allowInsecureConnection, int port) throws IOException {
         super(port);
         this.x509Util = x509Util;
+        this.allowInsecureConnection = allowInsecureConnection;
+    }
+
+    /**
+     * Creates a unified server socket bound to the specified port, with the specified backlog, by calling
+     * {@link ServerSocket#ServerSocket(int, int)}.
+     * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a
+     * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of
+     * the <code>allowInsecureConnection</code> parameter.
+     * @param x509Util the X509Util that provides the SSLContext to use for secure connections.
+     * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them.
+     * @param port the port number, or {@code 0} to use a port number that is automatically allocated.
+     * @param backlog requested maximum length of the queue of incoming connections.
+     * @throws IOException if {@link ServerSocket#ServerSocket(int, int)} throws.
+     */
+    public UnifiedServerSocket(X509Util x509Util,
+                               boolean allowInsecureConnection,
+                               int port,
+                               int backlog) throws IOException {
+        super(port, backlog);
+        this.x509Util = x509Util;
+        this.allowInsecureConnection = allowInsecureConnection;
+    }
+
+    /**
+     * Creates a unified server socket bound to the specified port, with the specified backlog, and local IP address
+     * to bind to, by calling {@link ServerSocket#ServerSocket(int, int, InetAddress)}.
+     * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a
+     * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of
+     * the <code>allowInsecureConnection</code> parameter.
+     * @param x509Util the X509Util that provides the SSLContext to use for secure connections.
+     * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them.
+     * @param port the port number, or {@code 0} to use a port number that is automatically allocated.
+     * @param backlog requested maximum length of the queue of incoming connections.
+     * @param bindAddr the local InetAddress the server will bind to.
+     * @throws IOException if {@link ServerSocket#ServerSocket(int, int, InetAddress)} throws.
+     */
+    public UnifiedServerSocket(X509Util x509Util,
+                               boolean allowInsecureConnection,
+                               int port,
+                               int backlog,
+                               InetAddress bindAddr) throws IOException {
+        super(port, backlog, bindAddr);
+        this.x509Util = x509Util;
+        this.allowInsecureConnection = allowInsecureConnection;
     }
 
     @Override
@@ -56,24 +144,642 @@ public class UnifiedServerSocket extends ServerSocket {
         }
         final PrependableSocket prependableSocket = new PrependableSocket(null);
         implAccept(prependableSocket);
+        return new UnifiedSocket(x509Util, allowInsecureConnection, prependableSocket);
+    }
+
+    /**
+     * The result of calling accept() on a UnifiedServerSocket. This is a Socket that doesn't know if it's
+     * using plaintext or SSL/TLS at the time when it is created. Calling a method that indicates a desire to
+     * read or write from the socket will cause the socket to detect if the connected client is attempting
+     * to establish a TLS or plaintext connection. This is done by doing a blocking read of 5 bytes off the
+     * socket and checking if the bytes look like the start of a TLS ClientHello message. If it looks like
+     * the client is attempting to connect with TLS, the internal socket is upgraded to a SSLSocket. If not,
+     * any bytes read from the socket are pushed back to the input stream, and the socket continues
+     * to be treated as a plaintext socket.
+     *
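+     * The operations that trigger this behavior are:
+     * <ul>
+     *     <li>reading from the stream returned by {@link UnifiedSocket#getInputStream()}</li>
+     *     <li>writing to the stream returned by {@link UnifiedSocket#getOutputStream()}</li>
+     *     <li>calling {@link UnifiedSocket#sendUrgentData(int)} or {@link UnifiedSocket#getSslSocket()}</li>
+     * </ul>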
+     *
+     * Calling other socket methods (e.g. option setters such as {@link Socket#setTcpNoDelay(boolean)}) does
+     * not trigger mode detection.
+     *
+     * Because detecting the mode is a potentially blocking operation, it should not be done in the
+     * accepting thread. Attempting to read from or write to the socket in the accepting thread opens the
+     * caller up to a denial-of-service attack, in which a client connects and then does nothing. This would
+     * prevent any other clients from connecting. Passing the socket returned by accept() to a separate
+     * thread which handles all read and write operations protects against this DoS attack.
+     *
+     * Callers can check if the socket has been upgraded to TLS by calling {@link UnifiedSocket#isSecureSocket()},
+     * and can get the underlying SSLSocket by calling {@link UnifiedSocket#getSslSocket()}.
+     */
+    public static class UnifiedSocket extends Socket {
+        private enum Mode {
+            UNKNOWN,
+            PLAINTEXT,
+            TLS
+        }
 
-        byte[] litmus = new byte[5];
-        int bytesRead = prependableSocket.getInputStream().read(litmus, 0, 5);
-        prependableSocket.prependToInputStream(litmus);
+        private final X509Util x509Util;
+        private final boolean allowInsecureConnection;
+        private PrependableSocket prependableSocket;
+        private SSLSocket sslSocket;
+        private Mode mode;
 
-        if (bytesRead == 5 && SslHandler.isEncrypted(Unpooled.wrappedBuffer(litmus))) {
-            LOG.info(getInetAddress() + " attempting to connect over ssl");
-            SSLSocket sslSocket;
+        /**
+         * Note: this constructor is intentionally private. The only intended caller is
+         * {@link UnifiedServerSocket#accept()}.
+         *
+         * @param x509Util
+         * @param allowInsecureConnection
+         * @param prependableSocket
+         */
+        private UnifiedSocket(X509Util x509Util, boolean allowInsecureConnection, PrependableSocket prependableSocket) {
+            this.x509Util = x509Util;
+            this.allowInsecureConnection = allowInsecureConnection;
+            this.prependableSocket = prependableSocket;
+            this.sslSocket = null;
+            this.mode = Mode.UNKNOWN;
+        }
+
+        /**
+         * Returns true if the socket mode has been determined to be TLS.
+         * @return true if the mode is TLS, false if it is UNKNOWN or PLAINTEXT.
+         */
+        public boolean isSecureSocket() {
+            return mode == Mode.TLS;
+        }
+
+        /**
+         * Returns true if the socket mode has been determined to be PLAINTEXT.
+         * @return true if the mode is PLAINTEXT, false if it is UNKNOWN or TLS.
+         */
+        public boolean isPlaintextSocket() {
+            return mode == Mode.PLAINTEXT;
+        }
+
+        /**
+         * Returns true if the socket mode has already been determined.
+         * @return true if the mode is PLAINTEXT or TLS, false if it is still UNKNOWN.
+         */
+        public boolean isModeKnown() {
+            return mode != Mode.UNKNOWN;
+        }
+
+        /**
+         * Detects the socket mode, see comments at the top of the class for more details. This operation will block
+         * for up to {@link X509Util#getSslHandshakeTimeoutMillis()} milliseconds and should not be called in the
+         * accept() thread if possible.
+         * @throws IOException
+         */
+        private void detectMode() throws IOException {
+            byte[] litmus = new byte[5];
+            int oldTimeout = -1;
+            int bytesRead = 0;
+            int newTimeout = x509Util.getSslHandshakeTimeoutMillis();
             try {
-                sslSocket = x509Util.createSSLSocket(prependableSocket);
-            } catch (X509Exception e) {
-                throw new IOException("failed to create SSL context", e);
+                oldTimeout = prependableSocket.getSoTimeout();
+                prependableSocket.setSoTimeout(newTimeout);
+                bytesRead = prependableSocket.getInputStream().read(litmus, 0, litmus.length);
+            } catch (SocketTimeoutException e) {
+                // Didn't read anything within the timeout, fallthrough and assume the connection is plaintext.
+                LOG.warn("Socket mode detection timed out after " + newTimeout + " ms, assuming PLAINTEXT");
+            } finally {
+                // restore socket timeout to the old value
+                try {
+                    if (oldTimeout != -1) {
+                        prependableSocket.setSoTimeout(oldTimeout);
+                    }
+                } catch (Exception e) {
+                    LOG.warn("Failed to restore old socket timeout value of " + oldTimeout + " ms", e);
+                }
+            }
+            if (bytesRead < 0) { // Got a EOF right away, definitely not using TLS. Fallthrough.
+                bytesRead = 0;
+            }
+
+            if (bytesRead == litmus.length && SslHandler.isEncrypted(Unpooled.wrappedBuffer(litmus))) {
+                try {
+                    sslSocket = x509Util.createSSLSocket(prependableSocket, litmus);
+                } catch (X509Exception e) {
+                    throw new IOException("failed to create SSL context", e);
+                }
+                prependableSocket = null;
+                mode = Mode.TLS;
+            } else if (allowInsecureConnection) {
+                prependableSocket.prependToInputStream(litmus, 0, bytesRead);
+                mode = Mode.PLAINTEXT;
+            } else {
+                prependableSocket.close();
+                mode = Mode.PLAINTEXT;
+                throw new IOException("Blocked insecure connection attempt");
+            }
+        }
+
+        private Socket getSocketAllowUnknownMode() {
+            if (isSecureSocket()) {
+                return sslSocket;
+            } else { // Note: mode is UNKNOWN or PLAINTEXT
+                return prependableSocket;
+            }
+        }
+
+        /**
+         * Returns the underlying socket, detecting the socket mode if it is not yet known. This is a potentially
+         * blocking operation and should not be called in the accept() thread.
+         * @return the underlying socket, after the socket mode has been determined.
+         * @throws IOException
+         */
+        private Socket getSocket() throws IOException {
+            if (!isModeKnown()) {
+                detectMode();
+            }
+            if (mode == Mode.TLS) {
+                return sslSocket;
+            } else {
+                return prependableSocket;
+            }
+        }
+
+        /**
+         * Returns the underlying SSLSocket if the mode is TLS. If the mode is UNKNOWN, causes mode detection which is a
+         * potentially blocking operation. If the mode ends up being PLAINTEXT, this will throw a SocketException, so
+         * callers are advised to only call this method after checking that {@link UnifiedSocket#isSecureSocket()}
+         * returned true.
+         * @return the underlying SSLSocket if the mode is known to be TLS.
+         * @throws IOException if detecting the socket mode fails
+         * @throws SocketException if the mode is PLAINTEXT.
+         */
+        public SSLSocket getSslSocket() throws IOException {
+            if (!isModeKnown()) {
+                detectMode();
+            }
+            if (!isSecureSocket()) {
+                throw new SocketException("Socket mode is not TLS");
             }
-            sslSocket.setUseClientMode(false);
             return sslSocket;
-        } else {
-            LOG.info(getInetAddress() + " attempting to connect without ssl");
-            return prependableSocket;
         }
+
+        /**
+         * See {@link Socket#connect(SocketAddress)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void connect(SocketAddress endpoint) throws IOException {
+            getSocketAllowUnknownMode().connect(endpoint);
+        }
+
+        /**
+         * See {@link Socket#connect(SocketAddress, int)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void connect(SocketAddress endpoint, int timeout) throws IOException {
+            getSocketAllowUnknownMode().connect(endpoint, timeout);
+        }
+
+        /**
+         * See {@link Socket#bind(SocketAddress)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void bind(SocketAddress bindpoint) throws IOException {
+            getSocketAllowUnknownMode().bind(bindpoint);
+        }
+
+        /**
+         * See {@link Socket#getInetAddress()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public InetAddress getInetAddress() {
+            return getSocketAllowUnknownMode().getInetAddress();
+        }
+
+        /**
+         * See {@link Socket#getLocalAddress()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public InetAddress getLocalAddress() {
+            return getSocketAllowUnknownMode().getLocalAddress();
+        }
+
+        /**
+         * See {@link Socket#getPort()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public int getPort() {
+            return getSocketAllowUnknownMode().getPort();
+        }
+
+        /**
+         * See {@link Socket#getLocalPort()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public int getLocalPort() {
+            return getSocketAllowUnknownMode().getLocalPort();
+        }
+
+        /**
+         * See {@link Socket#getRemoteSocketAddress()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public SocketAddress getRemoteSocketAddress() {
+            return getSocketAllowUnknownMode().getRemoteSocketAddress();
+        }
+
+        /**
+         * See {@link Socket#getLocalSocketAddress()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public SocketAddress getLocalSocketAddress() {
+            return getSocketAllowUnknownMode().getLocalSocketAddress();
+        }
+
+        /**
+         * See {@link Socket#getChannel()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public SocketChannel getChannel() {
+            return getSocketAllowUnknownMode().getChannel();
+        }
+
+        /**
+         * See {@link Socket#getInputStream()}. If the socket mode has not yet been detected, the first read from the
+         * returned input stream will trigger mode detection, which is a potentially blocking operation. This means
+         * the accept() thread should avoid reading from this input stream if possible.
+         */
+        @Override
+        public InputStream getInputStream() throws IOException {
+            return new UnifiedInputStream(this);
+        }
+
+        /**
+         * See {@link Socket#getOutputStream()}. If the socket mode has not yet been detected, the first write to the
+         * returned output stream will trigger mode detection, which is a potentially blocking operation. This means
+         * the accept() thread should avoid writing to this output stream if possible.
+         */
+        @Override
+        public OutputStream getOutputStream() throws IOException {
+            return new UnifiedOutputStream(this);
+        }
+
+        /**
+         * See {@link Socket#setTcpNoDelay(boolean)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void setTcpNoDelay(boolean on) throws SocketException {
+            getSocketAllowUnknownMode().setTcpNoDelay(on);
+        }
+
+        /**
+         * See {@link Socket#getTcpNoDelay()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean getTcpNoDelay() throws SocketException {
+            return getSocketAllowUnknownMode().getTcpNoDelay();
+        }
+
+        /**
+         * See {@link Socket#setSoLinger(boolean, int)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void setSoLinger(boolean on, int linger) throws SocketException {
+            getSocketAllowUnknownMode().setSoLinger(on, linger);
+        }
+
+        /**
+         * See {@link Socket#getSoLinger()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public int getSoLinger() throws SocketException {
+            return getSocketAllowUnknownMode().getSoLinger();
+        }
+
+        /**
+         * See {@link Socket#sendUrgentData(int)}. Calling this method triggers mode detection, which is a potentially
+         * blocking operation, so it should not be done in the accept() thread.
+         */
+        @Override
+        public void sendUrgentData(int data) throws IOException {
+            getSocket().sendUrgentData(data);
+        }
+
+        /**
+         * See {@link Socket#setOOBInline(boolean)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void setOOBInline(boolean on) throws SocketException {
+            getSocketAllowUnknownMode().setOOBInline(on);
+        }
+
+        /**
+         * See {@link Socket#getOOBInline()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean getOOBInline() throws SocketException {
+            return getSocketAllowUnknownMode().getOOBInline();
+        }
+
+        /**
+         * See {@link Socket#setSoTimeout(int)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public synchronized void setSoTimeout(int timeout) throws SocketException {
+            getSocketAllowUnknownMode().setSoTimeout(timeout);
+        }
+
+        /**
+         * See {@link Socket#getSoTimeout()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public synchronized int getSoTimeout() throws SocketException {
+            return getSocketAllowUnknownMode().getSoTimeout();
+        }
+
+        /**
+         * See {@link Socket#setSendBufferSize(int)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public synchronized void setSendBufferSize(int size) throws SocketException {
+            getSocketAllowUnknownMode().setSendBufferSize(size);
+        }
+
+        /**
+         * See {@link Socket#getSendBufferSize()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public synchronized int getSendBufferSize() throws SocketException {
+            return getSocketAllowUnknownMode().getSendBufferSize();
+        }
+
+        /**
+         * See {@link Socket#setReceiveBufferSize(int)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public synchronized void setReceiveBufferSize(int size) throws SocketException {
+            getSocketAllowUnknownMode().setReceiveBufferSize(size);
+        }
+
+        /**
+         * See {@link Socket#getReceiveBufferSize()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public synchronized int getReceiveBufferSize() throws SocketException {
+            return getSocketAllowUnknownMode().getReceiveBufferSize();
+        }
+
+        /**
+         * See {@link Socket#setKeepAlive(boolean)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void setKeepAlive(boolean on) throws SocketException {
+            getSocketAllowUnknownMode().setKeepAlive(on);
+        }
+
+        /**
+         * See {@link Socket#getKeepAlive()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean getKeepAlive() throws SocketException {
+            return getSocketAllowUnknownMode().getKeepAlive();
+        }
+
+        /**
+         * See {@link Socket#setTrafficClass(int)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void setTrafficClass(int tc) throws SocketException {
+            getSocketAllowUnknownMode().setTrafficClass(tc);
+        }
+
+        /**
+         * See {@link Socket#getTrafficClass()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public int getTrafficClass() throws SocketException {
+            return getSocketAllowUnknownMode().getTrafficClass();
+        }
+
+        /**
+         * See {@link Socket#setReuseAddress(boolean)}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void setReuseAddress(boolean on) throws SocketException {
+            getSocketAllowUnknownMode().setReuseAddress(on);
+        }
+
+        /**
+         * See {@link Socket#getReuseAddress()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean getReuseAddress() throws SocketException {
+            return getSocketAllowUnknownMode().getReuseAddress();
+        }
+
+        /**
+         * See {@link Socket#close()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public synchronized void close() throws IOException {
+            getSocketAllowUnknownMode().close();
+        }
+
+        /**
+         * See {@link Socket#shutdownInput()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void shutdownInput() throws IOException {
+            getSocketAllowUnknownMode().shutdownInput();
+        }
+
+        /**
+         * See {@link Socket#shutdownOutput()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public void shutdownOutput() throws IOException {
+            getSocketAllowUnknownMode().shutdownOutput();
+        }
+
+        /**
+         * Returns {@code UnifiedSocket[mode=..., socket=...]}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public String toString() {
+            return "UnifiedSocket[mode=" + mode + ", socket=" + getSocketAllowUnknownMode() + "]";
+        }
+
+        /**
+         * See {@link Socket#isConnected()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean isConnected() {
+            return getSocketAllowUnknownMode().isConnected();
+        }
+
+        /**
+         * See {@link Socket#isBound()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean isBound() {
+            return getSocketAllowUnknownMode().isBound();
+        }
+
+        /**
+         * See {@link Socket#isClosed()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean isClosed() {
+            return getSocketAllowUnknownMode().isClosed();
+        }
+
+        /**
+         * See {@link Socket#isInputShutdown()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean isInputShutdown() {
+            return getSocketAllowUnknownMode().isInputShutdown();
+        }
+
+        /**
+         * See {@link Socket#isOutputShutdown()}. Calling this method does not trigger mode detection.
+         */
+        @Override
+        public boolean isOutputShutdown() {
+            return getSocketAllowUnknownMode().isOutputShutdown();
+        }
+
+        /**
+         * See {@link Socket#setPerformancePreferences(int, int, int)}. Calling this method does not trigger
+         * mode detection.
+         */
+        @Override
+        public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) {
+            getSocketAllowUnknownMode().setPerformancePreferences(connectionTime, latency, bandwidth);
+        }
+    }
+
+    /**
+     * An input stream for a UnifiedSocket. The first read from this stream will trigger mode detection on the
+     * underlying UnifiedSocket.
+     */
+    private static class UnifiedInputStream extends InputStream {
+        private final UnifiedSocket unifiedSocket;
+        private InputStream realInputStream;
+
+        private UnifiedInputStream(UnifiedSocket unifiedSocket) {
+            this.unifiedSocket = unifiedSocket;
+            this.realInputStream = null;
+        }
+
+        @Override
+        public int read() throws IOException {
+            return getRealInputStream().read();
+        }
+
+        /**
+         * Note: SocketInputStream has optimized implementations of bulk-read operations, so we need to call them
+         * directly instead of relying on the base-class implementation which just calls the single-byte read() over
+         * and over. Not implementing these results in awful performance.
+         */
+        @Override
+        public int read(byte[] b) throws IOException {
+            return getRealInputStream().read(b);
+        }
+
+        @Override
+        public int read(byte[] b, int off, int len) throws IOException {
+            return getRealInputStream().read(b, off, len);
+        }
+
+        private InputStream getRealInputStream() throws IOException {
+            if (realInputStream == null) {
+                // Note: The first call to getSocket() triggers mode detection which can block
+                realInputStream = unifiedSocket.getSocket().getInputStream();
+            }
+            return realInputStream;
+        }
+
+        @Override
+        public long skip(long n) throws IOException {
+            return getRealInputStream().skip(n);
+        }
+
+        @Override
+        public int available() throws IOException {
+            return getRealInputStream().available();
+        }
+
+        @Override
+        public void close() throws IOException {
+            getRealInputStream().close();
+        }
+
+        @Override
+        public synchronized void mark(int readlimit) {
+            try {
+                getRealInputStream().mark(readlimit);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        @Override
+        public synchronized void reset() throws IOException {
+            getRealInputStream().reset();
+        }
+
+        @Override
+        public boolean markSupported() {
+            try {
+                return getRealInputStream().markSupported();
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+    }
+
+    private static class UnifiedOutputStream extends OutputStream {
+        private final UnifiedSocket unifiedSocket;
+        private OutputStream realOutputStream;
+
+        private UnifiedOutputStream(UnifiedSocket unifiedSocket) {
+            this.unifiedSocket = unifiedSocket;
+            this.realOutputStream = null;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            getRealOutputStream().write(b);
+        }
+
+        @Override
+        public void write(byte[] b) throws IOException {
+            getRealOutputStream().write(b);
+        }
+
+        @Override
+        public void write(byte[] b, int off, int len) throws IOException {
+            getRealOutputStream().write(b, off, len);
+        }
+
+        @Override
+        public void flush() throws IOException {
+            getRealOutputStream().flush();
+        }
+
+        @Override
+        public void close() throws IOException {
+            getRealOutputStream().close();
+        }
+
+        private OutputStream getRealOutputStream() throws IOException {
+            if (realOutputStream == null) {
+                // Note: The first call to getSocket() triggers mode detection which can block
+                realOutputStream = unifiedSocket.getSocket().getOutputStream();
+            }
+            return realOutputStream;
+        }
+
     }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
index 6b343c3..546cf55 100644
--- a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
@@ -356,6 +356,34 @@ public class X509UtilTest extends BaseX509ParameterizedTestCase {
                 true);
     }
 
+    @Test
+    public void testGetSslHandshakeDetectionTimeoutMillisProperty() {
+        X509Util x509Util = new ClientX509Util();
+        Assert.assertEquals(
+                X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS,
+                x509Util.getSslHandshakeTimeoutMillis());
+        try {
+            String newPropertyString = Integer.toString(X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS + 1);
+            System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), newPropertyString);
+            // Note: need to create a new ClientX509Util to pick up modified property value
+            Assert.assertEquals(
+                    X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS + 1,
+                    new ClientX509Util().getSslHandshakeTimeoutMillis());
+            // 0 value not allowed, will return the default
+            System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "0");
+            Assert.assertEquals(
+                    X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS,
+                    new ClientX509Util().getSslHandshakeTimeoutMillis());
+            // Negative value not allowed, will return the default
+            System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "-1");
+            Assert.assertEquals(
+                    X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS,
+                    new ClientX509Util().getSslHandshakeTimeoutMillis());
+        } finally {
+            System.clearProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty());
+        }
+    }
+
     // Warning: this will reset the x509Util
     private void setCustomCipherSuites() {
         System.setProperty(x509Util.getCipherSuitesProperty(), customCipherSuites[0] + "," + customCipherSuites[1]);

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java
index b088f47..67c15ad 100644
--- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java
@@ -80,7 +80,6 @@ import org.bouncycastle.util.io.pem.PemWriter;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
-import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.Timeout;
@@ -442,7 +441,6 @@ public class QuorumSSLTest extends QuorumPeerTestBase {
         Assert.assertFalse(ClientBase.waitForServerUp("127.0.0.1:" + clientPortQp3, CONNECTION_TIMEOUT));
     }
 
-    @Ignore("portUnification is currently broken and disabled")
     @Test
     public void testRollingUpgrade() throws Exception {
         // Form a quorum without ssl

http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java
----------------------------------------------------------------------
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java
new file mode 100644
index 0000000..61862a4
--- /dev/null
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java
@@ -0,0 +1,404 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zookeeper.server.quorum;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.SocketOptions;
+import java.security.Security;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.zookeeper.PortAssignment;
+import org.apache.zookeeper.ZKTestCase;
+import org.apache.zookeeper.common.ClientX509Util;
+import org.apache.zookeeper.common.KeyStoreFileType;
+import org.apache.zookeeper.common.X509KeyType;
+import org.apache.zookeeper.common.X509TestContext;
+import org.apache.zookeeper.common.X509Util;
+import org.apache.zookeeper.test.ClientBase;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Verifies that common operations on a socket accepted by UnifiedServerSocket do
+ * not trigger TLS vs. plaintext mode detection, which blocks on a socket read.
+ * This is necessary to ensure that the Leader's accept() thread doesn't get blocked.
+ */
+@RunWith(Parameterized.class)
+public class UnifiedServerSocketModeDetectionTest extends ZKTestCase {
+    private static final Logger LOG = LoggerFactory.getLogger(
+            UnifiedServerSocketModeDetectionTest.class);
+
+    @Parameterized.Parameters
+    public static Collection<Object[]> params() {
+        Collection<Object[]> secureClientFlags = new ArrayList<>(2);
+        secureClientFlags.add(new Object[] { Boolean.TRUE });  // useSecureClient = true
+        secureClientFlags.add(new Object[] { Boolean.FALSE }); // useSecureClient = false
+        return secureClientFlags;
+    }
+
+    private static File tempDir;
+    private static X509TestContext x509TestContext;
+
+    private boolean useSecureClient;
+    private X509Util x509Util;
+    private UnifiedServerSocket listeningSocket;
+    private UnifiedServerSocket.UnifiedSocket serverSideSocket;
+    private Socket clientSocket;
+    private ExecutorService workerPool;
+    private int port;
+    private InetSocketAddress localServerAddress;
+
+    @BeforeClass
+    public static void setUpClass() throws Exception {
+        Security.addProvider(new BouncyCastleProvider());
+        tempDir = ClientBase.createEmptyTestDir();
+        x509TestContext = X509TestContext.newBuilder()
+                .setTempDir(tempDir)
+                .setKeyStoreKeyType(X509KeyType.EC)
+                .setTrustStoreKeyType(X509KeyType.EC)
+                .build();
+    }
+
+    @AfterClass
+    public static void tearDownClass() {
+        try {
+            FileUtils.deleteDirectory(tempDir);
+        } catch (IOException e) {
+            LOG.warn("Failed to delete temp dir " + tempDir + "; it may be left behind", e);
+        }
+        Security.removeProvider(BouncyCastleProvider.PROVIDER_NAME); // undo setUpClass()
+    }
+
+    private static void forceClose(Socket s) { // best-effort: null-safe, never throws
+        try {
+            if (s != null && !s.isClosed()) {
+                s.close();
+            }
+        } catch (IOException e) {
+            LOG.debug("Ignoring failure to close socket during cleanup", e);
+        }
+    }
+
+    private static void forceClose(ServerSocket s) { // best-effort: null-safe, never throws
+        try {
+            if (s != null && !s.isClosed()) {
+                s.close();
+            }
+        } catch (IOException e) {
+            LOG.debug("Ignoring failure to close server socket during cleanup", e);
+        }
+    }
+
+    public UnifiedServerSocketModeDetectionTest(Boolean useSecureClient) {
+        this.useSecureClient = useSecureClient;
+    }
+
+    @Before
+    public void setUp() throws Exception {
+        x509Util = new ClientX509Util();
+        x509TestContext.setSystemProperties(x509Util, KeyStoreFileType.JKS, KeyStoreFileType.JKS);
+        System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "100");
+        workerPool = Executors.newCachedThreadPool();
+        port = PortAssignment.unique();
+        localServerAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), port);
+        listeningSocket = new UnifiedServerSocket(x509Util, true);
+        listeningSocket.bind(localServerAddress);
+        // accept() runs on a worker thread so this thread can drive the client side.
+        Future<UnifiedServerSocket.UnifiedSocket> acceptFuture = workerPool.submit(
+                new Callable<UnifiedServerSocket.UnifiedSocket>() {
+            @Override
+            public UnifiedServerSocket.UnifiedSocket call() throws Exception {
+                try {
+                    return (UnifiedServerSocket.UnifiedSocket) listeningSocket.accept();
+                } catch (IOException e) {
+                    LOG.error("Error in accept(): ", e);
+                    throw e;
+                }
+            }
+        });
+        clientSocket = useSecureClient ? x509Util.createSSLSocket() : new Socket();
+        clientSocket.connect(localServerAddress);
+        if (!useSecureClient) {
+            // Plaintext client sends 5 bytes right away; the tests expect the mode to stay unknown.
+            clientSocket.getOutputStream().write(new byte[] { 1, 2, 3, 4, 5 });
+        }
+        // Bounded wait: fail with a TimeoutException instead of hanging if accept() ever blocks.
+        serverSideSocket = acceptFuture.get(10, TimeUnit.SECONDS);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        x509TestContext.clearSystemProperties(x509Util);
+        System.clearProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty());
+        forceClose(listeningSocket);
+        forceClose(serverSideSocket);
+        forceClose(clientSocket);
+        workerPool.shutdown();
+        workerPool.awaitTermination(1000, TimeUnit.MILLISECONDS);
+    }
+
+    @Test
+    public void testGetInetAddress() {
+        serverSideSocket.getInetAddress();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testGetLocalAddress() {
+        serverSideSocket.getLocalAddress();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testGetPort() {
+        serverSideSocket.getPort();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testGetLocalPort() {
+        serverSideSocket.getLocalPort();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testGetRemoteSocketAddress() {
+        serverSideSocket.getRemoteSocketAddress();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testGetLocalSocketAddress() {
+        serverSideSocket.getLocalSocketAddress();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testGetInputStream() throws IOException {
+        serverSideSocket.getInputStream();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testGetOutputStream() throws IOException {
+        serverSideSocket.getOutputStream();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testGetTcpNoDelay() throws IOException {
+        serverSideSocket.getTcpNoDelay();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testSetTcpNoDelay() throws IOException {
+        // Flip the current setting so the read-back below proves the setter took effect.
+        final boolean flipped = !serverSideSocket.getTcpNoDelay();
+        serverSideSocket.setTcpNoDelay(flipped);
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        Assert.assertEquals(flipped, serverSideSocket.getTcpNoDelay());
+    }
+
+    @Test
+    public void testGetSoLinger() throws IOException {
+        serverSideSocket.getSoLinger();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    /**
+     * Toggles SO_LINGER, checks that the setter does not trigger mode detection,
+     * and reads the new value back. getSoLinger() returns -1 while the option is
+     * disabled; in that case lingering is enabled with a value of 1, otherwise
+     * it is switched off again.
+     */
+    @Test
+    public void testSetSoLinger() throws IOException {
+        final boolean enable = serverSideSocket.getSoLinger() == -1;
+        final int linger = enable ? 1 : -1;
+
+        serverSideSocket.setSoLinger(enable, linger);
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        Assert.assertEquals(linger, serverSideSocket.getSoLinger());
+    }
+
+    @Test
+    public void testGetSoTimeout() throws IOException {
+        serverSideSocket.getSoTimeout();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testSetSoTimeout() throws IOException {
+        // Use a value different from the current one so the read-back is meaningful.
+        final int bumpedTimeout = serverSideSocket.getSoTimeout() + 10;
+        serverSideSocket.setSoTimeout(bumpedTimeout);
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        Assert.assertEquals(bumpedTimeout, serverSideSocket.getSoTimeout());
+    }
+
+    @Test
+    public void testGetSendBufferSize() throws IOException {
+        serverSideSocket.getSendBufferSize();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testSetSendBufferSize() throws IOException {
+        final int requestedSize = serverSideSocket.getSendBufferSize() + 1024;
+        serverSideSocket.setSendBufferSize(requestedSize);
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        // The requested size is only a hint that the socket implementation
+        // may ignore, so there is deliberately no assertion that the getter
+        // returns the value we just set.
+    }
+
+    @Test
+    public void testGetReceiveBufferSize() throws IOException {
+        serverSideSocket.getReceiveBufferSize();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testSetReceiveBufferSize() throws IOException {
+        final int requestedSize = serverSideSocket.getReceiveBufferSize() + 1024;
+        serverSideSocket.setReceiveBufferSize(requestedSize);
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        // The requested size is only a hint that the socket implementation
+        // may ignore, so there is deliberately no assertion that the getter
+        // returns the value we just set.
+    }
+
+    @Test
+    public void testGetKeepAlive() throws IOException {
+        serverSideSocket.getKeepAlive();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testSetKeepAlive() throws IOException {
+        // Flip the current setting so the read-back below proves the setter took effect.
+        final boolean flipped = !serverSideSocket.getKeepAlive();
+        serverSideSocket.setKeepAlive(flipped);
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        Assert.assertEquals(flipped, serverSideSocket.getKeepAlive());
+    }
+
+    @Test
+    public void testGetTrafficClass() throws IOException {
+        serverSideSocket.getTrafficClass();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testSetTrafficClass() throws IOException {
+        serverSideSocket.setTrafficClass(SocketOptions.IP_TOS);
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        // Note: per the Socket javadocs, setTrafficClass() is only a hint that implementations
+        // may ignore, so the value is not read back and checked.
+        // NOTE(review): SocketOptions.IP_TOS is an option id, not a TOS value -- confirm intent.
+    }
+
+    @Test
+    public void testGetReuseAddress() throws IOException {
+        serverSideSocket.getReuseAddress();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testSetReuseAddress() throws IOException {
+        // Flip the current setting so the read-back below proves the setter took effect.
+        final boolean flipped = !serverSideSocket.getReuseAddress();
+        serverSideSocket.setReuseAddress(flipped);
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        Assert.assertEquals(flipped, serverSideSocket.getReuseAddress());
+    }
+
+    @Test
+    public void testClose() throws IOException {
+        serverSideSocket.close();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testShutdownInput() throws IOException {
+        serverSideSocket.shutdownInput();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testShutdownOutput() throws IOException {
+        serverSideSocket.shutdownOutput();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testIsConnected() {
+        serverSideSocket.isConnected();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testIsBound() {
+        serverSideSocket.isBound();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testIsClosed() {
+        serverSideSocket.isClosed();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+    }
+
+    @Test
+    public void testIsInputShutdown() throws IOException {
+        serverSideSocket.isInputShutdown();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        serverSideSocket.shutdownInput();
+        Assert.assertTrue(serverSideSocket.isInputShutdown());
+    }
+
+    @Test
+    public void testIsOutputShutdown() throws IOException {
+        serverSideSocket.isOutputShutdown();
+        Assert.assertFalse(serverSideSocket.isModeKnown());
+        serverSideSocket.shutdownOutput();
+        Assert.assertTrue(serverSideSocket.isOutputShutdown());
+    }
+}