You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by ns...@apache.org on 2015/11/23 09:11:05 UTC

[3/4] thrift git commit: THRIFT-3420 C++: TSSLSockets are not interruptable Client: C++ Patch: Martin Haimberger

THRIFT-3420 C++: TSSLSockets are not interruptable
Client: C++
Patch: Martin Haimberger

This closes #690


Project: http://git-wip-us.apache.org/repos/asf/thrift/repo
Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/0ad6ee95
Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/0ad6ee95
Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/0ad6ee95

Branch: refs/heads/master
Commit: 0ad6ee95e002f41dd628d4044f901468f43ffc32
Parents: ae971ce
Author: Martin Haimberger <ma...@thincast.com>
Authored: Fri Nov 13 03:18:50 2015 -0800
Committer: Nobuaki Sukegawa <ns...@apache.org>
Committed: Mon Nov 23 17:09:27 2015 +0900

----------------------------------------------------------------------
 .../src/thrift/transport/TSSLServerSocket.cpp   |   7 +-
 lib/cpp/src/thrift/transport/TSSLSocket.cpp     | 259 +++++++++++++++--
 lib/cpp/src/thrift/transport/TSSLSocket.h       |  46 +++
 lib/cpp/src/thrift/transport/TServerSocket.cpp  |  16 +-
 lib/cpp/src/thrift/transport/TServerSocket.h    |   4 +-
 lib/cpp/test/CMakeLists.txt                     |   7 +-
 lib/cpp/test/Makefile.am                        |   4 +-
 lib/cpp/test/TSSLSocketInterruptTest.cpp        | 283 +++++++++++++++++++
 8 files changed, 595 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
----------------------------------------------------------------------
diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
index 7e1484d..89423b4 100644
--- a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
@@ -48,7 +48,12 @@ TSSLServerSocket::TSSLServerSocket(int port,
 }
 
 boost::shared_ptr<TSocket> TSSLServerSocket::createSocket(THRIFT_SOCKET client) {
-  return factory_->createSocket(client);
+  // Share the server's child-interrupt reader so interruptChildren() can wake a
+  // child blocked in SSL I/O; otherwise hand out a plain TSSLSocket.
+  if (!interruptableChildren_) {
+    return factory_->createSocket(client);
+  }
+  return factory_->createSocket(client, pChildInterruptSockReader_);
 }
 }
 }

http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TSSLSocket.cpp
----------------------------------------------------------------------
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
index 98c5326..6e9a4de 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
@@ -28,6 +28,14 @@
 #ifdef HAVE_SYS_SOCKET_H
 #include <sys/socket.h>
 #endif
+#ifdef HAVE_SYS_POLL_H
+#include <sys/poll.h>
+#endif
+#ifdef HAVE_FCNTL_H
+#include <fcntl.h>
+#endif
+
+
 #include <boost/lexical_cast.hpp>
 #include <boost/shared_array.hpp>
 #include <openssl/err.h>
@@ -189,14 +197,28 @@ TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx)
   : TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
 }
 
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+  : TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
+  interruptListener_ = interruptListener; // assigned after construction; polled by waitForEvent()
+}
+
 TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket)
   : TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) {
 }
 
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+  : TSocket(socket, interruptListener), server_(false), ssl_(NULL), ctx_(ctx) {
+}
+
 TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port)
   : TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
 }
 
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+  : TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
+  interruptListener_ = interruptListener; // assigned after construction; polled by waitForEvent()
+}
+
 TSSLSocket::~TSSLSocket() {
   close();
 }
@@ -222,16 +244,32 @@ bool TSSLSocket::peek() {
   checkHandshake();
   int rc;
   uint8_t byte;
-  rc = SSL_peek(ssl_, &byte, 1);
-  if (rc < 0) {
-    int errno_copy = THRIFT_GET_SOCKET_ERROR;
-    string errors;
-    buildErrors(errors, errno_copy);
-    throw TSSLException("SSL_peek: " + errors);
-  }
-  if (rc == 0) {
-    ERR_clear_error();
-  }
+  do {
+    rc = SSL_peek(ssl_, &byte, 1);
+    if (rc < 0) {
+      // EINTR/EAGAIN and WANT_READ/WANT_WRITE: wait for the socket, then retry
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      int error = SSL_get_error(ssl_, rc);
+      switch (error) {
+        case SSL_ERROR_SYSCALL:
+          if ((errno_copy != THRIFT_EINTR)
+              && (errno_copy != THRIFT_EAGAIN)) {
+            break;
+          }
+        case SSL_ERROR_WANT_READ: // fallthrough from SYSCALL on EINTR/EAGAIN
+        case SSL_ERROR_WANT_WRITE:
+          waitForEvent(error != SSL_ERROR_WANT_WRITE);
+          continue;
+        default:;// do nothing
+      }
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException("SSL_peek: " + errors);
+    } else if (rc == 0) {
+      ERR_clear_error();
+      break;
+    }
+  } while (rc < 0); // rc > 0: a byte is available, stop (while (true) spun forever here)
   return (rc > 0);
 }
 
@@ -244,7 +282,28 @@ void TSSLSocket::open() {
 
 void TSSLSocket::close() {
   if (ssl_ != NULL) {
-    int rc = SSL_shutdown(ssl_);
+    int rc;
+    try { // close() also runs from ~TSSLSocket: a failed/interrupted wait must not escape
+      do {
+        rc = SSL_shutdown(ssl_);
+        if (rc <= 0) {
+          int errno_copy = THRIFT_GET_SOCKET_ERROR;
+          int error = SSL_get_error(ssl_, rc);
+          switch (error) {
+            case SSL_ERROR_SYSCALL:
+              if ((errno_copy != THRIFT_EINTR)
+                  && (errno_copy != THRIFT_EAGAIN)) {
+                break;
+              }
+            case SSL_ERROR_WANT_READ: // fallthrough from SYSCALL on EINTR/EAGAIN
+            case SSL_ERROR_WANT_WRITE:
+              waitForEvent(error == SSL_ERROR_WANT_READ);
+              rc = 2; // not a real SSL_shutdown() result: means "retry"
+            default:;// do nothing
+          }
+        }
+      } while (rc == 2);
+    } catch (const TTransportException&) { rc = -1; } // handled by the rc < 0 branch below
     if (rc < 0) {
       int errno_copy = THRIFT_GET_SOCKET_ERROR;
       string errors;
@@ -262,14 +321,36 @@ uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
   checkHandshake();
   int32_t bytes = 0;
   for (int32_t retries = 0; retries < maxRecvRetries_; retries++) {
+    ERR_clear_error();
     bytes = SSL_read(ssl_, buf, len);
     if (bytes >= 0)
       break;
-    int errno_copy = THRIFT_GET_SOCKET_ERROR;
-    if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
-      if (ERR_get_error() == 0 && errno_copy == THRIFT_EINTR) {
+    int32_t errno_copy = THRIFT_GET_SOCKET_ERROR;
+    int32_t error = SSL_get_error(ssl_, bytes);
+    switch (error) {
+      case SSL_ERROR_SYSCALL:
+        if ((errno_copy != THRIFT_EINTR)
+            && (errno_copy != THRIFT_EAGAIN)) {
+          break;
+        }
+        if (retries + 1 >= maxRecvRetries_) {
+          // EINTR/EAGAIN is tolerated only a limited number of times
+          break;
+        }
+        // fallthrough
+      case SSL_ERROR_WANT_READ:
+      case SSL_ERROR_WANT_WRITE:
+        if (waitForEvent(error != SSL_ERROR_WANT_WRITE) == TSSL_EINTR) {
+          // poll() was interrupted; the for-loop increment counts this retry
+          if (retries + 1 < maxRecvRetries_) {
+            continue;
+          }
+          throw TTransportException(TTransportException::INTERNAL_ERROR, "too much recv retries");
+        }
+        // A WANT_* wait is not a failed attempt, so don't let it use up a retry.
+        if (error != SSL_ERROR_SYSCALL) --retries;
         continue;
-      }
+      default:;// do nothing
     }
     string errors;
     buildErrors(errors, errno_copy);
@@ -283,9 +364,23 @@ void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
   // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
   uint32_t written = 0;
   while (written < len) {
+    ERR_clear_error();
     int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
     if (bytes <= 0) {
       int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      int error = SSL_get_error(ssl_, bytes);
+      switch (error) {
+        case SSL_ERROR_SYSCALL:
+          if ((errno_copy != THRIFT_EINTR)
+              && (errno_copy != THRIFT_EAGAIN)) {
+            break;
+          }
+        case SSL_ERROR_WANT_READ: // fallthrough from SYSCALL on EINTR/EAGAIN
+        case SSL_ERROR_WANT_WRITE:
+          waitForEvent(error == SSL_ERROR_WANT_READ); // SYSCALL waits for writability
+          continue;
+        default:;// do nothing
+      }
       string errors;
       buildErrors(errors, errno_copy);
       throw TSSLException("SSL_write: " + errors);
@@ -319,13 +414,63 @@ void TSSLSocket::checkHandshake() {
   if (ssl_ != NULL) {
     return;
   }
+
+  // set underlying socket to non-blocking
+  int flags;
+  if ((flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0)) < 0
+      || THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK) < 0) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TSSLSocket::checkHandshake: set THRIFT_O_NONBLOCK (THRIFT_FCNTL) ", errno_copy);
+    // Throw instead of returning with ssl_ NULL; socket_ is left to the normal close() path.
+    throw TTransportException(TTransportException::INTERNAL_ERROR, "THRIFT_FCNTL() failed", errno_copy);
+  }
+
   ssl_ = ctx_->createSSL();
+
+  // SSL_set_fd() attaches a socket BIO for the (now non-blocking) socket_;
+  // WANT_READ/WANT_WRITE from the handshake below are resolved by waitForEvent().
   SSL_set_fd(ssl_, static_cast<int>(socket_));
   int rc;
   if (server()) {
-    rc = SSL_accept(ssl_);
+    do {
+      rc = SSL_accept(ssl_);
+      if (rc <= 0) {
+        int errno_copy = THRIFT_GET_SOCKET_ERROR;
+        int error = SSL_get_error(ssl_, rc);
+        switch (error) {
+          case SSL_ERROR_SYSCALL:
+            if ((errno_copy != THRIFT_EINTR)
+                && (errno_copy != THRIFT_EAGAIN)) {
+              break;
+            }
+          case SSL_ERROR_WANT_READ: // fallthrough from SYSCALL on EINTR/EAGAIN
+          case SSL_ERROR_WANT_WRITE:
+            waitForEvent(error == SSL_ERROR_WANT_READ);
+            rc = 2; // not a real SSL_accept() result: means "retry"
+          default:;// do nothing
+        }
+      }
+    } while (rc == 2);
   } else {
-    rc = SSL_connect(ssl_);
+    do {
+      rc = SSL_connect(ssl_);
+      if (rc <= 0) {
+        int errno_copy = THRIFT_GET_SOCKET_ERROR;
+        int error = SSL_get_error(ssl_, rc);
+        switch (error) {
+          case SSL_ERROR_SYSCALL:
+            if ((errno_copy != THRIFT_EINTR)
+                && (errno_copy != THRIFT_EAGAIN)) {
+              break;
+            }
+          case SSL_ERROR_WANT_READ: // fallthrough from SYSCALL on EINTR/EAGAIN
+          case SSL_ERROR_WANT_WRITE:
+            waitForEvent(error == SSL_ERROR_WANT_READ);
+            rc = 2; // not a real SSL_connect() result: means "retry"
+          default:;// do nothing
+        }
+      }
+    } while (rc == 2);
   }
   if (rc <= 0) {
     int errno_copy = THRIFT_GET_SOCKET_ERROR;
@@ -443,6 +601,54 @@ void TSSLSocket::authorize() {
   }
 }
 
+unsigned int TSSLSocket::waitForEvent(bool wantRead) {
+  int fdSocket;
+  BIO* bio;
+
+  if (wantRead) {
+    bio = SSL_get_rbio(ssl_);
+  } else {
+    bio = SSL_get_wbio(ssl_);
+  }
+
+  if (bio == NULL) {
+    throw TSSLException("SSL_get_?bio returned NULL");
+  }
+  if (BIO_get_fd(bio, &fdSocket) <= 0) {
+    throw TSSLException("BIO_get_fd failed");
+  }
+
+  struct THRIFT_POLLFD fds[2];
+  std::memset(fds, 0, sizeof(fds));
+  fds[0].fd = fdSocket;
+  fds[0].events = wantRead ? THRIFT_POLLIN : THRIFT_POLLOUT;
+
+  if (interruptListener_) {
+    fds[1].fd = *(interruptListener_.get());
+    fds[1].events = THRIFT_POLLIN;
+  }
+
+  // The socket is non-blocking, so SO_RCVTIMEO/SO_SNDTIMEO no longer apply: enforce
+  // TSocket's recv/send timeout here (kept in ms like poll(); 0 = wait forever).
+  int timeout = wantRead ? recvTimeout_ : sendTimeout_;
+  int ret = THRIFT_POLL(fds, interruptListener_ ? 2 : 1, timeout > 0 ? timeout : -1);
+  if (ret < 0) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    if (errno_copy == THRIFT_EINTR) {
+      return TSSL_EINTR; // repeat operation
+    }
+    GlobalOutput.perror("TSSLSocket::waitForEvent THRIFT_POLL() ", errno_copy);
+    throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
+  } else if (ret > 0) {
+    if (fds[1].revents & THRIFT_POLLIN) {
+      throw TTransportException(TTransportException::INTERRUPTED, "Interrupted");
+    }
+    return TSSL_DATA;
+  } else {
+    throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_POLL (timed out)");
+  }
+}
+
 // TSSLSocketFactory implementation
 uint64_t TSSLSocketFactory::count_ = 0;
 Mutex TSSLSocketFactory::mutex_;
@@ -475,18 +681,37 @@ boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
   return ssl;
 }
 
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+  boost::shared_ptr<TSSLSocket> sslSocket(new TSSLSocket(ctx_, interruptListener)); // fresh socket, interruptible
+  setup(sslSocket);
+  return sslSocket;
+}
+
 boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket) {
   boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
   setup(ssl);
   return ssl;
 }
 
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+  boost::shared_ptr<TSSLSocket> sslSocket(new TSSLSocket(ctx_, socket, interruptListener)); // wraps an existing socket, interruptible
+  setup(sslSocket);
+  return sslSocket;
+}
+
 boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port) {
   boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
   setup(ssl);
   return ssl;
 }
 
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+  boost::shared_ptr<TSSLSocket> sslSocket(new TSSLSocket(ctx_, host, port, interruptListener)); // client socket, interruptible
+  setup(sslSocket);
+  return sslSocket;
+}
+
+
 void TSSLSocketFactory::setup(boost::shared_ptr<TSSLSocket> ssl) {
   ssl->server(server());
   if (access_ == NULL && !server()) {

http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TSSLSocket.h
----------------------------------------------------------------------
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h
index 02d414b..ba8abf4 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.h
@@ -43,6 +43,9 @@ enum SSLProtocol {
   LATEST  = TLSv1_2
 };
 
+#define TSSL_EINTR 0 // waitForEvent(): poll() was interrupted by a signal; repeat the operation
+#define TSSL_DATA 1  // waitForEvent(): the socket is ready for the requested I/O
+
 /**
  * Initialize OpenSSL library.  This function, or some other
  * equivalent function to initialize OpenSSL, must be called before
@@ -99,18 +102,35 @@ protected:
    */
   TSSLSocket(boost::shared_ptr<SSLContext> ctx);
   /**
+   * Constructor with an interrupt listener; SSL I/O that has to wait throws TTransportException::INTERRUPTED once it is readable.
+   */
+  TSSLSocket(boost::shared_ptr<SSLContext> ctx, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+  /**
    * Constructor, create an instance of TSSLSocket given an existing socket.
    *
    * @param socket An existing socket
    */
   TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket);
   /**
+   * Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted.
+   * @param socket An existing socket
+   * @param interruptListener Socket polled while SSL I/O waits; readable means "interrupt"
+   */
+  TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+   /**
    * Constructor.
    *
    * @param host  Remote host name
    * @param port  Remote port number
    */
   TSSLSocket(boost::shared_ptr<SSLContext> ctx, std::string host, int port);
+  /**
+   * Constructor with an interrupt listener.
+   * @param host  Remote host name
+   * @param port  Remote port number
+   * @param interruptListener  Socket polled while SSL I/O waits; readable means "interrupt"
+   */
+  TSSLSocket(boost::shared_ptr<SSLContext> ctx, std::string host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
   /**
    * Authorize peer access after SSL handshake completes.
    */
@@ -119,6 +139,15 @@ protected:
    * Initiate SSL handshake if not already initiated.
    */
   void checkHandshake();
+  /**
+   * Blocks until the socket is ready for reading (wantRead) or writing, or
+   * until the interrupt listener, if any, becomes readable.
+   * @throw TTransportException::INTERRUPTED if the interrupt listener is signaled,
+   *        TIMED_OUT if poll() times out, UNKNOWN if poll() fails.
+   * @return TSSL_EINTR if poll() was interrupted by a signal (EINTR);
+   *         TSSL_DATA  if the socket is ready.
+   */
+  unsigned int waitForEvent(bool wantRead);
 
   bool server_;
   SSL* ssl_;
@@ -144,12 +173,22 @@ public:
    */
   virtual boost::shared_ptr<TSSLSocket> createSocket();
   /**
+   * Create an instance of TSSLSocket with a fresh new socket whose waiting SSL I/O can be interrupted via interruptListener.
+   */
+  virtual boost::shared_ptr<TSSLSocket> createSocket(boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+  /**
    * Create an instance of TSSLSocket with the given socket.
    *
    * @param socket An existing socket.
    */
   virtual boost::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket);
   /**
+   * Create an instance of TSSLSocket with the given socket which is interruptable.
+   * @param socket An existing socket.
+   * @param interruptListener Socket polled while SSL I/O waits; readable means "interrupt".
+   */
+  virtual boost::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+  /**
   * Create an instance of TSSLSocket.
   *
   * @param host  Remote host to be connected to
@@ -157,6 +196,13 @@ public:
   */
   virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port);
   /**
+  * Create an interruptable instance of TSSLSocket.
+  * @param host  Remote host to be connected to
+  * @param port  Remote port to be connected to
+  * @param interruptListener  Socket polled while SSL I/O waits; readable means "interrupt"
+  */
+  virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+  /**
    * Set ciphers to be used in SSL handshake process.
    *
    * @param ciphers  A list of ciphers

http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TServerSocket.cpp
----------------------------------------------------------------------
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index 215cda6..137dc32 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -119,7 +119,8 @@ using namespace std;
 using boost::shared_ptr;
 
 TServerSocket::TServerSocket(int port)
-  : port_(port),
+  : interruptableChildren_(true),
+    port_(port),
     serverSocket_(THRIFT_INVALID_SOCKET),
     acceptBacklog_(DEFAULT_BACKLOG),
     sendTimeout_(0),
@@ -130,7 +131,6 @@ TServerSocket::TServerSocket(int port)
     tcpSendBuffer_(0),
     tcpRecvBuffer_(0),
     keepAlive_(false),
-    interruptableChildren_(true),
     listening_(false),
     interruptSockWriter_(THRIFT_INVALID_SOCKET),
     interruptSockReader_(THRIFT_INVALID_SOCKET),
@@ -138,7 +138,8 @@ TServerSocket::TServerSocket(int port)
 }
 
 TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
-  : port_(port),
+  : interruptableChildren_(true),
+    port_(port),
     serverSocket_(THRIFT_INVALID_SOCKET),
     acceptBacklog_(DEFAULT_BACKLOG),
     sendTimeout_(sendTimeout),
@@ -149,7 +150,6 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
     tcpSendBuffer_(0),
     tcpRecvBuffer_(0),
     keepAlive_(false),
-    interruptableChildren_(true),
     listening_(false),
     interruptSockWriter_(THRIFT_INVALID_SOCKET),
     interruptSockReader_(THRIFT_INVALID_SOCKET),
@@ -157,7 +157,8 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
 }
 
 TServerSocket::TServerSocket(const string& address, int port)
-  : port_(port),
+  : interruptableChildren_(true),
+    port_(port),
     address_(address),
     serverSocket_(THRIFT_INVALID_SOCKET),
     acceptBacklog_(DEFAULT_BACKLOG),
@@ -169,7 +170,6 @@ TServerSocket::TServerSocket(const string& address, int port)
     tcpSendBuffer_(0),
     tcpRecvBuffer_(0),
     keepAlive_(false),
-    interruptableChildren_(true),
     listening_(false),
     interruptSockWriter_(THRIFT_INVALID_SOCKET),
     interruptSockReader_(THRIFT_INVALID_SOCKET),
@@ -177,7 +177,8 @@ TServerSocket::TServerSocket(const string& address, int port)
 }
 
 TServerSocket::TServerSocket(const string& path)
-  : port_(0),
+  : interruptableChildren_(true),
+    port_(0),
     path_(path),
     serverSocket_(THRIFT_INVALID_SOCKET),
     acceptBacklog_(DEFAULT_BACKLOG),
@@ -189,7 +190,6 @@ TServerSocket::TServerSocket(const string& path)
     tcpSendBuffer_(0),
     tcpRecvBuffer_(0),
     keepAlive_(false),
-    interruptableChildren_(true),
     listening_(false),
     interruptSockWriter_(THRIFT_INVALID_SOCKET),
     interruptSockReader_(THRIFT_INVALID_SOCKET),

http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TServerSocket.h
----------------------------------------------------------------------
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index 58e4afd..20a37e7 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -123,6 +123,8 @@ public:
 protected:
   boost::shared_ptr<TTransport> acceptImpl();
   virtual boost::shared_ptr<TSocket> createSocket(THRIFT_SOCKET client);
+  bool interruptableChildren_;
+  boost::shared_ptr<THRIFT_SOCKET> pChildInterruptSockReader_; // if interruptableChildren_ this is shared with child TSockets
 
 private:
   void notify(THRIFT_SOCKET notifySock);
@@ -140,13 +142,11 @@ private:
   int tcpSendBuffer_;
   int tcpRecvBuffer_;
   bool keepAlive_;
-  bool interruptableChildren_;
   bool listening_;
 
   THRIFT_SOCKET interruptSockWriter_;                          // is notified on interrupt()
   THRIFT_SOCKET interruptSockReader_;                          // is used in select/poll with serverSocket_ for interruptability
   THRIFT_SOCKET childInterruptSockWriter_;                     // is notified on interruptChildren()
-  boost::shared_ptr<THRIFT_SOCKET> pChildInterruptSockReader_; // if interruptableChildren_ this is shared with child TSockets
 
   socket_func_t listenCallback_;
   socket_func_t acceptCallback_;

http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/test/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/lib/cpp/test/CMakeLists.txt b/lib/cpp/test/CMakeLists.txt
index 02932cb..491d343 100644
--- a/lib/cpp/test/CMakeLists.txt
+++ b/lib/cpp/test/CMakeLists.txt
@@ -92,7 +92,10 @@ if ( MSVC )
 endif ( MSVC )
 
 
-set( TInterruptTest_SOURCES TSocketInterruptTest.cpp )
+set( TInterruptTest_SOURCES
+     TSocketInterruptTest.cpp
+     TSSLSocketInterruptTest.cpp
+)
 if (WIN32)
     list(APPEND TInterruptTest_SOURCES
         TPipeInterruptTest.cpp
@@ -108,7 +111,7 @@ LINK_AGAINST_THRIFT_LIBRARY(TInterruptTest thrift)
 if (NOT MSVC AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
 target_link_libraries(TInterruptTest -lrt)
 endif ()
-add_test(NAME TInterruptTest COMMAND TInterruptTest)
+add_test(NAME TInterruptTest COMMAND TInterruptTest "${CMAKE_CURRENT_SOURCE_DIR}/../../../test/keys")
 
 add_executable(TServerIntegrationTest TServerIntegrationTest.cpp)
 target_link_libraries(TServerIntegrationTest

http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/test/Makefile.am
----------------------------------------------------------------------
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index 4b7a99f..1895afc 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -128,11 +128,13 @@ UnitTests_LDADD = \
   $(BOOST_TEST_LDADD)
 
 TInterruptTest_SOURCES = \
-	TSocketInterruptTest.cpp
+	TSocketInterruptTest.cpp \
+	TSSLSocketInterruptTest.cpp
 
 TInterruptTest_LDADD = \
   libtestgencpp.la \
   $(BOOST_TEST_LDADD) \
+  $(BOOST_FILESYSTEM_LDADD) \
   $(BOOST_CHRONO_LDADD) \
   $(BOOST_SYSTEM_LDADD) \
   $(BOOST_THREAD_LDADD)

http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/test/TSSLSocketInterruptTest.cpp
----------------------------------------------------------------------
diff --git a/lib/cpp/test/TSSLSocketInterruptTest.cpp b/lib/cpp/test/TSSLSocketInterruptTest.cpp
new file mode 100644
index 0000000..c723d0e
--- /dev/null
+++ b/lib/cpp/test/TSSLSocketInterruptTest.cpp
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <boost/test/auto_unit_test.hpp>
+#include <boost/test/unit_test_suite.hpp>
+#include <boost/bind.hpp>
+#include <boost/chrono/duration.hpp>
+#include <boost/date_time/posix_time/posix_time_duration.hpp>
+#include <boost/thread/thread.hpp>
+#include <boost/filesystem.hpp>
+#include <boost/format.hpp>
+#include <boost/shared_ptr.hpp>
+#include <thrift/transport/TSSLSocket.h>
+#include <thrift/transport/TSSLServerSocket.h>
+#include "TestPortFixture.h"
+#ifdef __linux__
+#include <signal.h>
+#endif
+
+using apache::thrift::transport::TSSLServerSocket;
+using apache::thrift::transport::TSSLSocket;
+using apache::thrift::transport::TTransport;
+using apache::thrift::transport::TTransportException;
+using apache::thrift::transport::TSSLSocketFactory;
+
+boost::filesystem::path keyDir; // set by GlobalFixtureSSL to the directory holding the test keys
+boost::filesystem::path certFile(const std::string& filename)
+{
+  return keyDir / filename;
+}
+boost::mutex gMutex; // NOTE(review): not referenced in the visible part of this file
+
+struct GlobalFixtureSSL // process-wide OpenSSL setup/teardown and test-key discovery
+{
+    GlobalFixtureSSL()
+    {
+      using namespace boost::unit_test::framework;
+      for (int i = 0; i < master_test_suite().argc; ++i)
+      {
+        BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]);
+      }
+
+#ifdef __linux__
+      // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
+      // disconnected can cause a SIGPIPE signal...
+      signal(SIGPIPE, SIG_IGN);
+#endif
+
+      TSSLSocketFactory::setManualOpenSSLInitialization(true);
+      apache::thrift::transport::initializeOpenSSL();
+
+      keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys"; // cwd three levels below the repo root
+      if (!boost::filesystem::exists(certFile("server.crt")))
+      {
+        keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]); // fallback: last command-line argument
+        if (!boost::filesystem::exists(certFile("server.crt")))
+        {
+          throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
+        }
+      }
+    }
+
+    virtual ~GlobalFixtureSSL()
+    {
+      apache::thrift::transport::cleanupOpenSSL();
+#ifdef __linux__
+      signal(SIGPIPE, SIG_DFL);
+#endif
+    }
+};
+
+#if (BOOST_VERSION >= 105900) // the two spellings below differ only in the trailing ';'
+BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL);
+#else
+BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL)
+#endif
+
+BOOST_FIXTURE_TEST_SUITE(TSSLSocketInterruptTest, TestPortFixture) // test cases use m_serverPort from the fixture
+
+void readerWorker(boost::shared_ptr<TTransport> tt, uint32_t expectedResult) {
+  uint8_t buffer[4];
+  try {
+    tt->read(buffer, 1); // consume the single byte the client writes first
+    BOOST_CHECK_EQUAL(expectedResult, tt->read(buffer, 4));
+  } catch (const TTransportException& ex) {
+    BOOST_CHECK_EQUAL(TTransportException::INTERNAL_ERROR, ex.getType());
+  }
+}
+
+void readerWorkerMustThrow(boost::shared_ptr<TTransport> tt) {
+  try { // one of the two reads must be interrupted
+    uint8_t buffer[400];
+    tt->read(buffer, 1);
+    tt->read(buffer, 400);
+    BOOST_ERROR("should not have gotten here");
+  } catch (const TTransportException& ex) {
+    BOOST_CHECK_EQUAL(TTransportException::INTERRUPTED, ex.getType());
+  }
+}
+
+boost::shared_ptr<TSSLSocketFactory> createServerSocketFactory() {
+  boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory(new TSSLSocketFactory());
+
+  pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  // path::string() is a narrow std::string everywhere (native() is wide on Windows)
+  pServerSocketFactory->loadCertificate(certFile("server.crt").string().c_str());
+  pServerSocketFactory->loadPrivateKey(certFile("server.key").string().c_str());
+  pServerSocketFactory->server(true);
+  return pServerSocketFactory;
+}
+
+boost::shared_ptr<TSSLSocketFactory> createClientSocketFactory() {
+  boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory(new TSSLSocketFactory());
+
+  pClientSocketFactory->authenticate(true);
+  // path::string() is a narrow std::string everywhere (native() is wide on Windows)
+  pClientSocketFactory->loadCertificate(certFile("client.crt").string().c_str());
+  pClientSocketFactory->loadPrivateKey(certFile("client.key").string().c_str());
+  pClientSocketFactory->loadTrustedCertificates(certFile("CA.pem").string().c_str());
+  return pClientSocketFactory;
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_read_while_handshaking) { // client never writes: the child read waits in the server-side handshake
+  boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+  TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+  sock1.listen();
+  boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+  boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+  clientSock->open();
+  boost::shared_ptr<TTransport> accepted = sock1.accept();
+  boost::thread readThread(boost::bind(readerWorkerMustThrow, accepted));
+  boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+  // readThread is practically guaranteed to be blocking now
+  sock1.interruptChildren();
+  BOOST_CHECK_MESSAGE(readThread.try_join_for(boost::chrono::milliseconds(20)),
+  "server socket interruptChildren did not interrupt child read");
+  clientSock->close();
+  accepted->close();
+  sock1.close();
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_read) {
+  boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+  TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+  sock1.listen();
+  boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+  boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+  clientSock->open();
+  boost::shared_ptr<TTransport> accepted = sock1.accept();
+  boost::thread readThread(boost::bind(readerWorkerMustThrow, accepted));
+  clientSock->write((const uint8_t*)"0", 1); // worker reads this byte, then waits in read(buf, 400)
+  boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+  // readThread is practically guaranteed to be blocking now
+  sock1.interruptChildren();
+  BOOST_CHECK_MESSAGE(readThread.try_join_for(boost::chrono::milliseconds(20)),
+                      "server socket interruptChildren did not interrupt child read");
+  accepted->close();
+  clientSock->close();
+  sock1.close();
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_non_interruptable_child_read) {
+  std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl;
+  boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+  TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+  sock1.setInterruptableChildren(false); // returns to pre-THRIFT-2441 behavior
+  sock1.listen();
+  boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+  boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+  clientSock->open();
+  boost::shared_ptr<TTransport> accepted = sock1.accept();
+  boost::thread readThread(boost::bind(readerWorker, accepted, 0)); // expects read(buf, 4) == 0 after the client disconnects
+  clientSock->write((const uint8_t*)"0", 1);
+  boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+  // readThread is practically guaranteed to be blocking here
+  sock1.interruptChildren();
+  BOOST_CHECK_MESSAGE(!readThread.try_join_for(boost::chrono::milliseconds(200)),
+                      "server socket interruptChildren interrupted child read");
+
+  // only way to proceed is to have the client disconnect
+  clientSock->close();
+  readThread.join();
+  accepted->close();
+  sock1.close();
+}
+
+// The interrupt mode of a TSSLServerSocket is fixed once it is listening:
+// setInterruptableChildren() after listen() must throw std::logic_error.
+BOOST_AUTO_TEST_CASE(test_ssl_cannot_change_after_listen) {
+  boost::shared_ptr<TSSLSocketFactory> serverFactory = createServerSocketFactory();
+  TSSLServerSocket server("localhost", m_serverPort, serverFactory);
+  server.listen();
+  BOOST_CHECK_THROW(server.setInterruptableChildren(false), std::logic_error);
+  server.close();
+}
+
+// Thread body for the non-interruptable peek test: consumes one byte from the
+// transport, then calls peek() and checks its result against expectedResult.
+// Any exception escapes the thread (there is no try/catch here).
+void peekerWorker(boost::shared_ptr<TTransport> tt, bool expectedResult) {
+  uint8_t scratch[400];
+  tt->read(scratch, 1);
+
+  const bool peeked = tt->peek();
+  BOOST_CHECK_EQUAL(expectedResult, peeked);
+}
+
+// Thread body for the interruptable peek test: consumes one byte, then blocks
+// in peek() and expects the server's interruptChildren() to abort it with
+// TTransportException::INTERRUPTED. If peek() returns normally, the interrupt
+// did not take effect and that is reported as a test error. Previously that
+// case passed silently.
+void peekerWorkerInterrupt(boost::shared_ptr<TTransport> tt) {
+  uint8_t buf[400];
+  try {
+    tt->read(buf, 1);
+    tt->peek();
+    // BOOST_ERROR is non-fatal (it does not throw), so it is neither caught
+    // below nor able to escape this thread.
+    BOOST_ERROR("peek() returned instead of being interrupted");
+  } catch (const TTransportException& tx) {
+    BOOST_CHECK_EQUAL(TTransportException::INTERRUPTED, tx.getType());
+  }
+}
+
+// Verifies that TSSLServerSocket::interruptChildren() releases a child
+// connection whose worker thread is blocked in peek(). The worker,
+// peekerWorkerInterrupt, checks that the interruption surfaces as
+// TTransportException::INTERRUPTED; here we require that the worker thread
+// finishes within 200ms of the interrupt.
+BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_peek) {
+  std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl;
+  boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+  TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+  sock1.listen();
+  boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+  boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+  clientSock->open();
+  boost::shared_ptr<TTransport> accepted = sock1.accept();
+  // An interrupted peek() is expected to throw TTransportException::INTERRUPTED
+  // (checked inside peekerWorkerInterrupt), not to return false.
+  boost::thread peekThread(boost::bind(peekerWorkerInterrupt, accepted));
+  clientSock->write((const uint8_t*)"0", 1);
+  boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+  // peekThread is practically guaranteed to be blocking now
+  sock1.interruptChildren();
+  BOOST_CHECK_MESSAGE(peekThread.try_join_for(boost::chrono::milliseconds(200)),
+                      "server socket interruptChildren did not interrupt child peek");
+#ifdef __linux__
+  // Presumably guards against SIGPIPE from SSL shutdown writes to an
+  // already-closed peer during the close() calls below.
+  signal(SIGPIPE, SIG_IGN);
+#endif
+  clientSock->close();
+  accepted->close();
+  sock1.close();
+}
+
+// Counterpart of test_ssl_interruptable_child_peek with interruptable children
+// switched off: interruptChildren() must leave the child's peek() blocked.
+// Once the client disconnects, peek() returns false, which peekerWorker
+// checks.
+BOOST_AUTO_TEST_CASE(test_ssl_non_interruptable_child_peek) {
+  std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl;
+  boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+  TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+  sock1.setInterruptableChildren(false); // returns to pre-THRIFT-2441 behavior
+  sock1.listen();
+  boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+  boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+  clientSock->open();
+  boost::shared_ptr<TTransport> accepted = sock1.accept();
+  // peek() will return false when remote side is closed
+  boost::thread peekThread(boost::bind(peekerWorker, accepted, false));
+  clientSock->write((const uint8_t*)"0", 1);
+  boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+  // peekThread is practically guaranteed to be blocking now
+  sock1.interruptChildren();
+  BOOST_CHECK_MESSAGE(!peekThread.try_join_for(boost::chrono::milliseconds(200)),
+                      "server socket interruptChildren interrupted child peek");
+
+  // only way to proceed is to have the client disconnect
+#ifdef __linux__
+  // Presumably guards against SIGPIPE from SSL shutdown writes to an
+  // already-closed peer during the close() calls below.
+  signal(SIGPIPE, SIG_IGN);
+#endif
+  clientSock->close();
+  peekThread.join();
+  accepted->close();
+  sock1.close();
+}
+
+BOOST_AUTO_TEST_SUITE_END()