You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by br...@apache.org on 2011/02/22 19:12:06 UTC

svn commit: r1073441 - in /thrift/trunk: ./ lib/cpp/ lib/cpp/src/transport/ test/cpp/src/

Author: bryanduxbury
Date: Tue Feb 22 18:12:06 2011
New Revision: 1073441

URL: http://svn.apache.org/viewvc?rev=1073441&view=rev
Log:
THRIFT-151. cpp: TSSLServerSocket and TSSLSocket implementation

This patch adds an implementation of the above ssl sockets.

Patch: Ping Li, Kevin Worth, Rowan Kerr

Added:
    thrift/trunk/README.SSL
    thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.cpp
    thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.h
    thrift/trunk/lib/cpp/src/transport/TSSLSocket.cpp
    thrift/trunk/lib/cpp/src/transport/TSSLSocket.h
Modified:
    thrift/trunk/lib/cpp/Makefile.am
    thrift/trunk/lib/cpp/src/transport/TServerSocket.cpp
    thrift/trunk/lib/cpp/src/transport/TServerSocket.h
    thrift/trunk/lib/cpp/src/transport/TSocket.h
    thrift/trunk/test/cpp/src/TestClient.cpp
    thrift/trunk/test/cpp/src/TestServer.cpp

Added: thrift/trunk/README.SSL
URL: http://svn.apache.org/viewvc/thrift/trunk/README.SSL?rev=1073441&view=auto
==============================================================================
--- thrift/trunk/README.SSL (added)
+++ thrift/trunk/README.SSL Tue Feb 22 18:12:06 2011
@@ -0,0 +1,135 @@
+Notes on Thrift/SSL
+
+Author: Ping Li <pi...@facebook.com>
+
+1. Scope
+
+   This SSL implementation only supports blocking-mode socket I/O. It can only
+   be used with TSimpleServer, TThreadedServer, and TThreadPoolServer.
+
+2. Implementation
+
+   There are two main classes, TSSLSocketFactory and TSSLSocket. Instances of
+   TSSLSocket are always created from TSSLSocketFactory.
+
+   PosixSSLThreadFactory creates PosixSSLThread. The only difference from the
+   PthreadThread type is that it cleans up the OpenSSL error queue upon
+   exiting the thread. Ideally, OpenSSL APIs should only be called from
+   PosixSSLThread.
+
+3. How to use SSL APIs
+
+   // This is for demo. In real code, typically only one TSSLSocketFactory
+   // instance is needed.
+   shared_ptr<TSSLSocketFactory> getSSLSocketFactory() {
+     shared_ptr<TSSLSocketFactory> factory(new TSSLSocketFactory());
+     // client: load trusted certificates
+     factory->loadTrustedCertificates("my-trusted-ca-certificates.pem");
+     // client: optionally set your own access manager, otherwise,
+     //         the default client access manager will be loaded.
+
+     factory->loadCertificate("my-certificate-signed-by-ca.pem");
+     factory->loadPrivateKey("my-private-key.pem");
+     // server: optionally setup access manager
+     // shared_ptr<AccessManager> accessManager(new MyAccessManager);
+     // factory->access(accessManager);
+     ...
+   }
+
+   // client code sample
+   shared_ptr<TSSLSocketFactory> factory = getSSLSocketFactory();
+   shared_ptr<TSocket> socket = factory->createSocket(host, port);
+   shared_ptr<TBufferedTransport> transport(new TBufferedTransport(socket));
+   ...
+
+   // server code sample
+   shared_ptr<TSSLSocketFactory> factory = getSSLSocketFactory();
+   shared_ptr<TSSLServerSocket> socket(new TSSLServerSocket(port, factory));
+   shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
+   ...
+
+4. AccessManager
+
+   AccessManager defines a callback interface. It has three callback methods:
+
+   (a) Decision verify(const sockaddr_storage& sa);
+   (b) Decision verify(const string& host, const char* name, int size);
+   (c) Decision verify(const sockaddr_storage& sa, const char* data, int size);
+
+   After the SSL handshake completes, additional checks are conducted. The
+   application is given the chance to decide whether or not to continue the
+   conversation with the remote peer. The application is queried through the
+   three "verify" methods above, which are called at different points of the
+   verification process.
+
+   A decision is one of ALLOW, DENY, and SKIP. ALLOW and DENY mean the
+   conversation should be continued or disconnected, respectively; either one
+   stops the verification process. SKIP means no decision could be made from
+   the given input, so the verification process continues.
+
+   First, (a) is called with the remote IP. It is called once at the beginning.
+   "sa" is the IP address of the remote peer.
+
+   Then, the certificate of remote peer is loaded. SubjectAltName extensions
+   are extracted and sent to application for verification. When a DNS
+   subjectAltName field is extracted, (b) is called. When an IP subjectAltName
+   field is extracted, (c) is called.
+
+   The "host" in (b) is the value from TSocket::getHost() if this is a
+   client-side socket, or TSocket::getPeerHost() if this is a server-side
+   socket. The reason is that the client-side socket initiates the connection,
+   so TSocket::getHost() is the remote host name. On the server side, the
+   remote host name is unknown unless it is retrieved through
+   TSocket::getPeerHost(). Either way, "host" should be the remote host name.
+   Keep in mind that if TSocket::getPeerHost() fails, it returns the remote
+   host in numeric format.
+
+   If all subjectAltName extensions were "skipped", the common name field is
+   checked. It is sent to the application through (b), where "name" and
+   "size" hold the common name. In (c), "sa" is the remote IP address, "data"
+   is the IP address extracted from a subjectAltName IP extension, and "size"
+   is the length of that extension data.
+
+   If any of the above "verify" methods returned a decision ALLOW or DENY, the
+   verification process would be stopped.
+
+   If a "verify" method returns SKIP, that result is ignored and the
+   verification process moves on until the last item has been examined. If
+   there is still no decision at that point, the connection is terminated.
+
+   Thread safety: an access manager should not store state information if it
+   is to be used by many SSL sockets.
+
+5. SIGPIPE signal
+
+   Applications running OpenSSL over network connections may crash if SIGPIPE
+   is not ignored. This happens when they write to a connection that has been
+   reset by the remote peer, which raises a SIGPIPE signal. If not handled,
+   this signal kills the application; ignore it (for example with
+   signal(SIGPIPE, SIG_IGN)) before using SSL sockets.
+
+6. How to run test client/server in SSL mode
+
+   The server expects the following files in the current working directory:
+   - "server-certificate.pem"
+   - "server-private-key.pem"
+
+   The client loads "trusted-ca-certificate.pem" from the current directory.
+
+   The file names are hard-coded in the source code. You need to create these
+   certificates before you can run the test code in SSL mode. Make sure at
+   least one of the following is included in "server-certificate.pem":
+   - subjectAltName, DNS localhost
+   - subjectAltName, IP  127.0.0.1
+   - common name,    localhost
+
+   Run,
+   - "./test_server --ssl" to start server
+   - "./test_client --ssl" to run client
+
+   If "-h <host>" is used to run the client, "localhost" in the above
+   server-certificate.pem has to be replaced with that host name.
+
+7. TSSLSocketFactory::randomize()
+
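+   The default implementation of TSSLSocketFactory::randomize() simply calls
+   OpenSSL's RAND_poll() when the OpenSSL library is first initialized.
+
+   The PRNG seed is key to the application's security. This method should be
+   overridden if it's not strong enough for you. Note that randomize() is
+   invoked from the TSSLSocketFactory constructor, where an override in a
+   subclass is not yet in effect; call your own seeding explicitly after the
+   factory has been constructed.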

Modified: thrift/trunk/lib/cpp/Makefile.am
URL: http://svn.apache.org/viewvc/thrift/trunk/lib/cpp/Makefile.am?rev=1073441&r1=1073440&r2=1073441&view=diff
==============================================================================
--- thrift/trunk/lib/cpp/Makefile.am (original)
+++ thrift/trunk/lib/cpp/Makefile.am Tue Feb 22 18:12:06 2011
@@ -60,8 +60,10 @@ libthrift_la_SOURCES = src/Thrift.cpp \
                        src/transport/THttpClient.cpp \
                        src/transport/THttpServer.cpp \
                        src/transport/TSocket.cpp \
+                       src/transport/TSSLSocket.cpp \
                        src/transport/TSocketPool.cpp \
                        src/transport/TServerSocket.cpp \
+                       src/transport/TSSLServerSocket.cpp \
                        src/transport/TTransportUtils.cpp \
                        src/transport/TBufferTransports.cpp \
                        src/server/TServer.cpp \
@@ -125,11 +127,13 @@ include_transport_HEADERS = \
                          src/transport/TFileTransport.h \
                          src/transport/TSimpleFileTransport.h \
                          src/transport/TServerSocket.h \
+                         src/transport/TSSLServerSocket.h \
                          src/transport/TServerTransport.h \
                          src/transport/THttpTransport.h \
                          src/transport/THttpClient.h \
                          src/transport/THttpServer.h \
                          src/transport/TSocket.h \
+                         src/transport/TSSLSocket.h \
                          src/transport/TSocketPool.h \
                          src/transport/TVirtualTransport.h \
                          src/transport/TTransport.h \

Added: thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.cpp
URL: http://svn.apache.org/viewvc/thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.cpp?rev=1073441&view=auto
==============================================================================
--- thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.cpp (added)
+++ thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.cpp Tue Feb 22 18:12:06 2011
@@ -0,0 +1,36 @@
+// Copyright (c) 2009- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+
+#include "TSSLServerSocket.h"
+#include "TSSLSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+using namespace boost;
+
+/**
+ * SSL server socket implementation.
+ *
+ * @author Ping Li <pi...@facebook.com>
+ */
+// Listening SSL socket on `port`. The shared factory is switched into
+// server-side handshake mode so every socket it produces accepts handshakes.
+TSSLServerSocket::TSSLServerSocket(int port,
+                                   shared_ptr<TSSLSocketFactory> factory)
+  : TServerSocket(port),
+    factory_(factory) {
+  factory_->server(true);
+}
+
+// As above, additionally forwarding the send/receive timeouts to
+// TServerSocket.
+TSSLServerSocket::TSSLServerSocket(int port,
+                                   int sendTimeout,
+                                   int recvTimeout,
+                                   shared_ptr<TSSLSocketFactory> factory)
+  : TServerSocket(port, sendTimeout, recvTimeout),
+    factory_(factory) {
+  factory_->server(true);
+}
+
+// Wrap a raw client descriptor in an SSL socket built by the factory.
+shared_ptr<TSocket> TSSLServerSocket::createSocket(int client) {
+  shared_ptr<TSocket> wrapped = factory_->createSocket(client);
+  return wrapped;
+}
+
+}}}

Added: thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.h
URL: http://svn.apache.org/viewvc/thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.h?rev=1073441&view=auto
==============================================================================
--- thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.h (added)
+++ thrift/trunk/lib/cpp/src/transport/TSSLServerSocket.h Tue Feb 22 18:12:06 2011
@@ -0,0 +1,48 @@
+// Copyright (c) 2009- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+
+#ifndef _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_
+#define _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_ 1
+
+#include <boost/shared_ptr.hpp>
+#include "TServerSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+class TSSLSocketFactory;
+
+/**
+ * Server socket that accepts SSL connections.
+ *
+ * @author Ping Li <pi...@facebook.com>
+ */
+class TSSLServerSocket : public TServerSocket {
+ public:
+  /**
+   * Create a listening SSL server socket.
+   *
+   * @param port    Port to listen on
+   * @param factory Factory that builds the SSL sockets; the constructor
+   *                switches it to server mode
+   */
+  TSSLServerSocket(int port, boost::shared_ptr<TSSLSocketFactory> factory);
+
+  /**
+   * Create a listening SSL server socket with explicit timeouts.
+   *
+   * @param port        Port to listen on
+   * @param sendTimeout Send timeout, passed through to TServerSocket
+   * @param recvTimeout Receive timeout, passed through to TServerSocket
+   * @param factory     Factory that builds the SSL sockets; the constructor
+   *                    switches it to server mode
+   */
+  TSSLServerSocket(int port,
+                   int sendTimeout,
+                   int recvTimeout,
+                   boost::shared_ptr<TSSLSocketFactory> factory);
+
+ protected:
+  /**
+   * Wrap a raw client descriptor in an SSL socket produced by factory_.
+   *
+   * @param client Descriptor of the accepted connection
+   */
+  boost::shared_ptr<TSocket> createSocket(int client);
+
+  /** Source of SSL sockets for accepted connections. */
+  boost::shared_ptr<TSSLSocketFactory> factory_;
+};
+
+}}}
+
+#endif

Added: thrift/trunk/lib/cpp/src/transport/TSSLSocket.cpp
URL: http://svn.apache.org/viewvc/thrift/trunk/lib/cpp/src/transport/TSSLSocket.cpp?rev=1073441&view=auto
==============================================================================
--- thrift/trunk/lib/cpp/src/transport/TSSLSocket.cpp (added)
+++ thrift/trunk/lib/cpp/src/transport/TSSLSocket.cpp Tue Feb 22 18:12:06 2011
@@ -0,0 +1,645 @@
+// Copyright (c) 2009- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+
+#include <errno.h>
+#include <string>
+#include <arpa/inet.h>
+#include <boost/lexical_cast.hpp>
+#include <boost/shared_array.hpp>
+#include <openssl/err.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+#include <openssl/x509v3.h>
+#include "concurrency/Mutex.h"
+#include "TSSLSocket.h"
+
+#define OPENSSL_VERSION_NO_THREAD_ID 0x10000000L
+
+using namespace std;
+using namespace boost;
+using namespace apache::thrift::concurrency;
+
+// OpenSSL only forward-declares this struct and leaves its definition to the
+// application. Here each dynamic lock is a single Mutex, created, locked and
+// destroyed by dyn_create/dyn_lock/dyn_destroy below.
+struct CRYPTO_dynlock_value { Mutex mutex; };
+
+namespace apache { namespace thrift { namespace transport {
+
+
+static void buildErrors(string& message, int error = 0);
+static bool matchName(const char* host, const char* pattern, int size);
+static char uppercase(char c);
+
+// SSLContext implementation
+// Allocate a TLSv1 SSL_CTX. SSL_MODE_AUTO_RETRY lets blocking reads and
+// writes retry internally instead of surfacing a retry request to callers.
+// Throws TSSLException carrying the OpenSSL error queue on failure.
+SSLContext::SSLContext() {
+  ctx_ = SSL_CTX_new(TLSv1_method());
+  if (ctx_ != NULL) {
+    SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
+    return;
+  }
+  string errors;
+  buildErrors(errors);
+  throw TSSLException("SSL_CTX_new: " + errors);
+}
+
+// Release the SSL_CTX (if one was created).
+SSLContext::~SSLContext() {
+  if (ctx_ == NULL) {
+    return;
+  }
+  SSL_CTX_free(ctx_);
+  ctx_ = NULL;
+}
+
+// Create a new SSL session object from this context; the caller owns it.
+SSL* SSLContext::createSSL() {
+  SSL* handle = SSL_new(ctx_);
+  if (handle != NULL) {
+    return handle;
+  }
+  string errors;
+  buildErrors(errors);
+  throw TSSLException("SSL_new: " + errors);
+}
+
+// TSSLSocket implementation
+// The constructors differ only in how the underlying TSocket is built
+// (unconnected, wrapping an existing descriptor, or targeting host:port).
+// Each starts in client mode with no SSL session; checkHandshake() creates
+// the session on first use.
+TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx)
+  : TSocket(),
+    server_(false),
+    ssl_(NULL),
+    ctx_(ctx) {
+}
+
+TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, int socket)
+  : TSocket(socket),
+    server_(false),
+    ssl_(NULL),
+    ctx_(ctx) {
+}
+
+TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, string host, int port)
+  : TSocket(host, port),
+    server_(false),
+    ssl_(NULL),
+    ctx_(ctx) {
+}
+
+// Shut down the SSL session (if any) and close the descriptor.
+TSSLSocket::~TSSLSocket() {
+  close();
+}
+
+// Open only while an SSL session exists, the descriptor is open, and
+// close_notify has not been both sent and received.
+bool TSSLSocket::isOpen() {
+  if (ssl_ == NULL || !TSocket::isOpen()) {
+    return false;
+  }
+  const int bothDirections = SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN;
+  return (SSL_get_shutdown(ssl_) & bothDirections) != bothDirections;
+}
+
+// Report whether at least one decrypted byte is readable without consuming
+// it. Returns false when the socket is not open or SSL_peek returns 0;
+// throws TSSLException when SSL_peek fails.
+bool TSSLSocket::peek() {
+  if (!isOpen()) {
+    return false;
+  }
+  checkHandshake();
+  uint8_t probe;
+  const int rc = SSL_peek(ssl_, &probe, 1);
+  if (rc > 0) {
+    return true;
+  }
+  if (rc == 0) {
+    // Nothing readable (e.g. the peer closed): discard queued errors.
+    ERR_clear_error();
+    return false;
+  }
+  int errno_copy = errno;
+  string errors;
+  buildErrors(errors, errno_copy);
+  throw TSSLException("SSL_peek: " + errors);
+}
+
+// Connect the underlying TCP socket (client mode only). The SSL handshake
+// itself is deferred to the first I/O call via checkHandshake().
+void TSSLSocket::open() {
+  const bool rejected = isOpen() || server();
+  if (rejected) {
+    throw TTransportException(TTransportException::BAD_ARGS);
+  }
+  TSocket::open();
+}
+
+// Shut down the SSL session (best effort, errors are only logged), free it,
+// and close the underlying descriptor.
+void TSSLSocket::close() {
+  if (ssl_ == NULL) {
+    TSocket::close();
+    return;
+  }
+  // 0 means our close_notify went out but the peer's has not arrived yet;
+  // call SSL_shutdown a second time to complete the bidirectional shutdown.
+  int status = SSL_shutdown(ssl_);
+  if (status == 0) {
+    status = SSL_shutdown(ssl_);
+  }
+  if (status < 0) {
+    int errno_copy = errno;
+    string errors;
+    buildErrors(errors, errno_copy);
+    GlobalOutput(("SSL_shutdown: " + errors).c_str());
+  }
+  SSL_free(ssl_);
+  ssl_ = NULL;
+  ERR_remove_state(0);
+  TSocket::close();
+}
+
+// Read up to `len` decrypted bytes into `buf`. Returns the number of bytes
+// read (0 when SSL_read reports no data, e.g. the peer closed). Calls that
+// fail with EINTR and no OpenSSL error are retried up to maxRecvRetries_
+// times; any other failure throws TSSLException.
+uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
+  checkHandshake();
+  int32_t bytes = 0;
+  for (int32_t retries = 0; retries < maxRecvRetries_; retries++) {
+    bytes = SSL_read(ssl_, buf, len);
+    if (bytes >= 0) {
+      break;
+    }
+    int errno_copy = errno;
+    if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
+      if (ERR_get_error() == 0 && errno_copy == EINTR) {
+        continue;
+      }
+    }
+    string errors;
+    buildErrors(errors, errno_copy);
+    throw TSSLException("SSL_read: " + errors);
+  }
+  if (bytes < 0) {
+    // Every attempt was interrupted. Returning the negative SSL_read result
+    // would turn into a huge uint32_t byte count, so fail explicitly.
+    throw TSSLException("SSL_read: interrupted " +
+                        lexical_cast<string>(maxRecvRetries_) +
+                        " times, giving up");
+  }
+  return bytes;
+}
+
+// Write all `len` bytes of `buf`, throwing TSSLException on any failure.
+void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
+  checkHandshake();
+  // SSL_write may return after a partial write when
+  // SSL_MODE_ENABLE_PARTIAL_WRITE is set in the SSL_CTX, so keep going.
+  for (uint32_t sent = 0; sent < len; ) {
+    int32_t chunk = SSL_write(ssl_, buf + sent, len - sent);
+    if (chunk <= 0) {
+      int errno_copy = errno;
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException("SSL_write: " + errors);
+    }
+    sent += chunk;
+  }
+}
+
+// Flush the SSL write BIO. A socket without an SSL session is silently
+// ignored because Thrift servers close sockets twice.
+void TSSLSocket::flush() {
+  if (ssl_ == NULL) {
+    return;
+  }
+  checkHandshake();
+  BIO* out = SSL_get_wbio(ssl_);
+  if (out == NULL) {
+    throw TSSLException("SSL_get_wbio returns NULL");
+  }
+  if (BIO_flush(out) == 1) {
+    return;
+  }
+  int errno_copy = errno;
+  string errors;
+  buildErrors(errors, errno_copy);
+  throw TSSLException("BIO_flush: " + errors);
+}
+
+// Lazily perform the SSL handshake (accept or connect, depending on mode)
+// and authorize the peer. Throws TTransportException(NOT_OPEN) if the TCP
+// socket is closed and TSSLException if the handshake or authorization fails.
+// On failure the half-built session is freed, so a later I/O call cannot
+// silently reuse an unauthenticated or unauthorized connection.
+void TSSLSocket::checkHandshake() {
+  if (!TSocket::isOpen()) {
+    throw TTransportException(TTransportException::NOT_OPEN);
+  }
+  if (ssl_ != NULL) {
+    return;
+  }
+  ssl_ = ctx_->createSSL();
+  try {
+    if (SSL_set_fd(ssl_, socket_) == 0) {
+      string errors;
+      buildErrors(errors);
+      throw TSSLException("SSL_set_fd: " + errors);
+    }
+    int rc;
+    if (server()) {
+      rc = SSL_accept(ssl_);
+    } else {
+      rc = SSL_connect(ssl_);
+    }
+    if (rc <= 0) {
+      int errno_copy = errno;
+      string fname(server() ? "SSL_accept" : "SSL_connect");
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException(fname + ": " + errors);
+    }
+    authorize();
+  } catch (...) {
+    // Fail closed: drop the session so the next call cannot skip the
+    // handshake/authorization checks above.
+    SSL_free(ssl_);
+    ssl_ = NULL;
+    throw;
+  }
+}
+
+// Decide whether the peer of a completed handshake may continue. Checks, in
+// order:
+//   1. OpenSSL's certificate verification result.
+//   2. Whether a peer certificate is present.
+//   3. When an AccessManager is set: the remote IP, then subjectAltName
+//      DNS/IP entries, then the subject commonName entries.
+// The first ALLOW/DENY decision wins; with no decision the peer is rejected.
+// Throws TSSLException when access is not granted.
+void TSSLSocket::authorize() {
+  int rc = SSL_get_verify_result(ssl_);
+  if (rc != X509_V_OK) {  // verify authentication result
+    throw TSSLException(string("SSL_get_verify_result(), ") +
+                        X509_verify_cert_error_string(rc));
+  }
+
+  X509* cert = SSL_get_peer_certificate(ssl_);
+  if (cert == NULL) {
+    // certificate is not present
+    if (SSL_get_verify_mode(ssl_) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
+      throw TSSLException("authorize: required certificate not present");
+    }
+    // certificate was optional: didn't intend to authorize remote
+    if (server() && access_ != NULL) {
+      throw TSSLException("authorize: certificate required for authorization");
+    }
+    return;
+  }
+  // certificate is present
+  if (access_ == NULL) {
+    X509_free(cert);
+    return;
+  }
+  // both certificate and access manager are present
+
+  // Remote name used for name checks: the host we connected to on the
+  // client side, the peer host on the server side. Filled in lazily, only
+  // when a name check actually needs it.
+  string host;
+  sockaddr_storage sa = {};
+  socklen_t saLength = sizeof(sa);
+
+  if (getpeername(socket_, (sockaddr*)&sa, &saLength) != 0) {
+    sa.ss_family = AF_UNSPEC;
+  }
+
+  AccessManager::Decision decision = access_->verify(sa);
+
+  if (decision != AccessManager::SKIP) {
+    X509_free(cert);
+    if (decision != AccessManager::ALLOW) {
+      throw TSSLException("authorize: access denied based on remote IP");
+    }
+    return;
+  }
+
+  // extract subjectAlternativeName
+  STACK_OF(GENERAL_NAME)* alternatives = (STACK_OF(GENERAL_NAME)*)
+                       X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL);
+  if (alternatives != NULL) {
+    const int count = sk_GENERAL_NAME_num(alternatives);
+    for (int i = 0; decision == AccessManager::SKIP && i < count; i++) {
+      const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i);
+      if (name == NULL) {
+        continue;
+      }
+      // Read only the union member that matches the entry's type; other
+      // entry types are not consulted.
+      switch (name->type) {
+        case GEN_DNS: {
+          char* data = (char*)ASN1_STRING_data(name->d.dNSName);
+          int length = ASN1_STRING_length(name->d.dNSName);
+          if (host.empty()) {
+            host = (server() ? getPeerHost() : getHost());
+          }
+          decision = access_->verify(host, data, length);
+          break;
+        }
+        case GEN_IPADD: {
+          char* data = (char*)ASN1_STRING_data(name->d.iPAddress);
+          int length = ASN1_STRING_length(name->d.iPAddress);
+          decision = access_->verify(sa, data, length);
+          break;
+        }
+        default:
+          break;
+      }
+    }
+    sk_GENERAL_NAME_pop_free(alternatives, GENERAL_NAME_free);
+  }
+
+  if (decision != AccessManager::SKIP) {
+    X509_free(cert);
+    if (decision != AccessManager::ALLOW) {
+      throw TSSLException("authorize: access denied");
+    }
+    return;
+  }
+
+  // extract commonName
+  X509_NAME* name = X509_get_subject_name(cert);
+  if (name != NULL) {
+    X509_NAME_ENTRY* entry;
+    int last = -1;
+    while (decision == AccessManager::SKIP) {
+      last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
+      if (last == -1)
+        break;
+      entry = X509_NAME_get_entry(name, last);
+      if (entry == NULL)
+        continue;
+      ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry);
+      unsigned char* utf8 = NULL;
+      int size = ASN1_STRING_to_UTF8(&utf8, common);
+      if (size < 0 || utf8 == NULL) {
+        // Conversion failed and nothing was allocated; try the next entry.
+        continue;
+      }
+      if (host.empty()) {
+        // Same host selection as the subjectAltName DNS check above.
+        host = (server() ? getPeerHost() : getHost());
+      }
+      decision = access_->verify(host, (char*)utf8, size);
+      OPENSSL_free(utf8);
+    }
+  }
+  X509_free(cert);
+  if (decision != AccessManager::ALLOW) {
+    throw TSSLException("authorize: cannot authorize peer");
+  }
+}
+
+// TSSLSocketFactory implementation
+// Process-wide OpenSSL bookkeeping shared by all factories, guarded by
+// mutex_: the first live factory initializes OpenSSL and the last one to be
+// destroyed cleans it up.
+bool     TSSLSocketFactory::initialized = false;
+uint64_t TSSLSocketFactory::count_ = 0;
+Mutex    TSSLSocketFactory::mutex_;
+
+// Create a client-mode factory with its own SSL context. The first factory
+// also initializes OpenSSL and seeds the PRNG. randomize() runs inside this
+// base-class constructor, so an override in a subclass is not used here.
+TSSLSocketFactory::TSSLSocketFactory(): server_(false) {
+  Guard guard(mutex_);
+  if (count_ == 0) {
+    initializeOpenSSL();
+    randomize();
+  }
+  count_++;
+  try {
+    ctx_ = shared_ptr<SSLContext>(new SSLContext);
+  } catch (...) {
+    // The destructor does not run for a partially constructed object, so
+    // undo the reference count here or OpenSSL would never be cleaned up.
+    count_--;
+    if (count_ == 0) {
+      cleanupOpenSSL();
+    }
+    throw;
+  }
+}
+
+// Drop this factory's reference and tear OpenSSL down with the last factory.
+TSSLSocketFactory::~TSSLSocketFactory() {
+  Guard guard(mutex_);
+  // Release our SSL_CTX reference before OpenSSL may be cleaned up below;
+  // otherwise the member would only be destroyed after cleanupOpenSSL().
+  // NOTE(review): sockets created by this factory share the context and may
+  // still outlive it.
+  ctx_.reset();
+  count_--;
+  if (count_ == 0) {
+    cleanupOpenSSL();
+  }
+}
+
+// Build an unconnected socket sharing this factory's SSL context.
+shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
+  shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_));
+  setup(created);
+  return created;
+}
+
+// Wrap an existing descriptor, e.g. one handed over by TSSLServerSocket.
+shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(int socket) {
+  shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_, socket));
+  setup(created);
+  return created;
+}
+
+// Build a socket that targets host:port when opened.
+shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host,
+                                                       int port) {
+  shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_, host, port));
+  setup(created);
+  return created;
+}
+
+// Copy the factory's mode and access manager onto a new socket. A client
+// factory without an explicit manager lazily installs a
+// DefaultClientAccessManager, which is cached in access_ for later sockets.
+void TSSLSocketFactory::setup(shared_ptr<TSSLSocket> ssl) {
+  ssl->server(server());
+  const bool needDefault = access_ == NULL && !server();
+  if (needDefault) {
+    access_ = shared_ptr<AccessManager>(new DefaultClientAccessManager);
+  }
+  if (access_ != NULL) {
+    ssl->access(access_);
+  }
+}
+
+// Restrict the context to the given OpenSSL cipher list string. Errors on
+// the OpenSSL error queue are reported first; otherwise a zero result means
+// none of the requested ciphers is available.
+void TSSLSocketFactory::ciphers(const string& enable) {
+  const int accepted = SSL_CTX_set_cipher_list(ctx_->get(), enable.c_str());
+  if (ERR_peek_error() != 0) {
+    string errors;
+    buildErrors(errors);
+    throw TSSLException("SSL_CTX_set_cipher_list: " + errors);
+  }
+  if (accepted == 0) {
+    throw TSSLException("None of specified ciphers are supported");
+  }
+}
+
+// Require (and verify once) a peer certificate, or disable peer
+// verification entirely.
+void TSSLSocketFactory::authenticate(bool required) {
+  const int mode = required
+      ? (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE)
+      : SSL_VERIFY_NONE;
+  SSL_CTX_set_verify(ctx_->get(), mode, NULL);
+}
+
+// Load this side's certificate chain from `path`. Only the "PEM" format is
+// supported. Throws TTransportException(BAD_ARGS) on NULL arguments and
+// TSSLException on an unsupported format or an OpenSSL failure.
+void TSSLSocketFactory::loadCertificate(const char* path, const char* format) {
+  if (path == NULL || format == NULL) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+         "loadCertificate: either <path> or <format> is NULL");
+  }
+  if (strcmp(format, "PEM") != 0) {
+    throw TSSLException("Unsupported certificate format: " + string(format));
+  }
+  if (SSL_CTX_use_certificate_chain_file(ctx_->get(), path) == 0) {
+    int errno_copy = errno;
+    string errors;
+    buildErrors(errors, errno_copy);
+    throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors);
+  }
+}
+
+// Load this side's private key from `path`. Only the "PEM" format is
+// supported; other formats are now rejected instead of silently ignored,
+// matching loadCertificate().
+void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
+  if (path == NULL || format == NULL) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+         "loadPrivateKey: either <path> or <format> is NULL");
+  }
+  if (strcmp(format, "PEM") != 0) {
+    throw TSSLException("Unsupported private key format: " + string(format));
+  }
+  if (SSL_CTX_use_PrivateKey_file(ctx_->get(), path, SSL_FILETYPE_PEM) == 0) {
+    int errno_copy = errno;
+    string errors;
+    buildErrors(errors, errno_copy);
+    throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
+  }
+}
+
+// Load the CA certificates in `path` used to verify the peer's certificate.
+void TSSLSocketFactory::loadTrustedCertificates(const char* path) {
+  if (path == NULL) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+         "loadTrustedCertificates: <path> is NULL");
+  }
+  if (SSL_CTX_load_verify_locations(ctx_->get(), path, NULL) == 0) {
+    int errno_copy = errno;
+    string errors;
+    buildErrors(errors, errno_copy);
+    throw TSSLException("SSL_CTX_load_verify_locations: " + errors);
+  }
+}
+
+// Seed OpenSSL's PRNG from system sources.
+void TSSLSocketFactory::randomize() {
+  RAND_poll();
+}
+
+// Route OpenSSL pass-phrase prompts for this context to passwordCallback(),
+// with this factory as the callback's user data.
+void TSSLSocketFactory::overrideDefaultPasswordCallback() {
+  SSL_CTX_set_default_passwd_cb(ctx_->get(), passwordCallback);
+  SSL_CTX_set_default_passwd_cb_userdata(ctx_->get(), this);
+}
+
+// OpenSSL pass-phrase callback: asks the factory for the password and
+// copies at most `size` bytes into `password`, returning the count copied.
+int TSSLSocketFactory::passwordCallback(char* password,
+                                        int size,
+                                        int,
+                                        void* data) {
+  TSSLSocketFactory* self = static_cast<TSSLSocketFactory*>(data);
+  string secret;
+  self->getPassword(secret, size);
+  int copied = static_cast<int>(secret.size());
+  if (copied > size) {
+    copied = size;
+  }
+  strncpy(password, secret.c_str(), copied);
+  return copied;
+}
+
+// Static lock table for OpenSSL. It is sized to CRYPTO_num_locks() in
+// initializeOpenSSL() and released in cleanupOpenSSL().
+static shared_array<Mutex> mutexes;
+
+// OpenSSL static locking callback: lock or unlock lock number `n`.
+static void callbackLocking(int mode, int n, const char*, int) {
+  Mutex& lock = mutexes[n];
+  if ((mode & CRYPTO_LOCK) != 0) {
+    lock.lock();
+    return;
+  }
+  lock.unlock();
+}
+
+#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
+// Thread-id callback, registered in initializeOpenSSL() only for OpenSSL
+// versions below OPENSSL_VERSION_NO_THREAD_ID (defined at the top of this file).
+// NOTE(review): the cast assumes pthread_t is an integer or pointer type that
+// fits in unsigned long. This holds on common POSIX platforms but is not
+// guaranteed by POSIX; confirm when porting.
+static unsigned long callbackThreadID() {
+  return reinterpret_cast<unsigned long>(pthread_self());
+}
+#endif
+
+// Dynamic-lock callbacks registered in initializeOpenSSL(): OpenSSL
+// allocates, locks/unlocks and frees CRYPTO_dynlock_value objects via these.
+static CRYPTO_dynlock_value* dyn_create(const char*, int) {
+  CRYPTO_dynlock_value* created = new CRYPTO_dynlock_value;
+  return created;
+}
+
+static void dyn_lock(int mode,
+                     struct CRYPTO_dynlock_value* lock,
+                     const char*, int) {
+  if (lock == NULL) {
+    return;
+  }
+  if ((mode & CRYPTO_LOCK) != 0) {
+    lock->mutex.lock();
+  } else {
+    lock->mutex.unlock();
+  }
+}
+
+static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
+  delete lock;
+}
+
+// One-time, process-wide OpenSSL setup. The TSSLSocketFactory constructor
+// calls it with mutex_ held when no other factory is alive; the
+// "initialized" flag turns repeated calls into no-ops.
+void TSSLSocketFactory::initializeOpenSSL() {
+  if (initialized) {
+    return;
+  }
+  initialized = true;
+  SSL_library_init();
+  SSL_load_error_strings();
+  // static locking
+  // One Mutex per OpenSSL static lock, used by callbackLocking().
+  mutexes = shared_array<Mutex>(new Mutex[::CRYPTO_num_locks()]);
+  // NOTE(review): new[] throws on allocation failure rather than returning
+  // NULL, so this branch appears unreachable.
+  if (mutexes == NULL) {
+    throw TTransportException(TTransportException::INTERNAL_ERROR,
+          "initializeOpenSSL() failed, "
+          "out of memory while creating mutex array");
+  }
+#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
+  // Older OpenSSL versions need an explicit thread-id callback.
+  CRYPTO_set_id_callback(callbackThreadID);
+#endif
+  CRYPTO_set_locking_callback(callbackLocking);
+  // dynamic locking
+  CRYPTO_set_dynlock_create_callback(dyn_create);
+  CRYPTO_set_dynlock_lock_callback(dyn_lock);
+  CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
+}
+
+// Undo initializeOpenSSL(). Called by the factory destructor (mutex_ held)
+// when the last factory goes away. The threading callbacks are unregistered
+// first, then OpenSSL's global tables are freed, and the static lock table
+// is dropped last.
+void TSSLSocketFactory::cleanupOpenSSL() {
+  if (!initialized) {
+    return;
+  }
+  initialized = false;
+#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
+  CRYPTO_set_id_callback(NULL);
+#endif
+  CRYPTO_set_locking_callback(NULL);
+  CRYPTO_set_dynlock_create_callback(NULL);
+  CRYPTO_set_dynlock_lock_callback(NULL);
+  CRYPTO_set_dynlock_destroy_callback(NULL);
+  CRYPTO_cleanup_all_ex_data();
+  ERR_free_strings();
+  EVP_cleanup();
+  // Frees the calling thread's error-queue state.
+  ERR_remove_state(0);
+  mutexes.reset();
+}
+
+// extract error messages from error queue
+void buildErrors(string& errors, int errno_copy) {
+  unsigned long  errorCode;
+  char   message[256];
+
+  errors.reserve(512);
+  while ((errorCode = ERR_get_error()) != 0) {
+    if (!errors.empty()) {
+      errors += "; ";
+    }
+    const char* reason = ERR_reason_error_string(errorCode);
+    if (reason == NULL) {
+      snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
+      reason = message;
+    }
+    errors += reason;
+  }
+  if (errors.empty()) {
+    if (errno_copy != 0) {
+      errors += TOutput::strerror_s(errno_copy);
+    }
+  }
+  if (errors.empty()) {
+    errors = "error code: " + lexical_cast<string>(errno_copy);
+  }
+}
+
+/**
+ * Default AccessManager, installed by TSSLSocketFactory::setup() for client
+ * sockets when no manager was configured. It allows a peer whose certificate
+ * names (DNS/commonName or IP) match the remote host; otherwise it SKIPs.
+ *
+ * The return type is spelled AccessManager::Decision because Decision is a
+ * member of AccessManager (see its use in TSSLSocket::authorize()).
+ */
+// The peer address alone never decides.
+AccessManager::Decision DefaultClientAccessManager::verify(
+    const sockaddr_storage& /* sa */) throw() {
+  return SKIP;
+}
+
+// ALLOW when the certificate name matches `host` (wildcards permitted).
+AccessManager::Decision DefaultClientAccessManager::verify(const string& host,
+                                                          const char* name,
+                                                          int size) throw() {
+  if (host.empty() || name == NULL || size <= 0) {
+    return SKIP;
+  }
+  return (matchName(host.c_str(), name, size) ? ALLOW : SKIP);
+}
+
+// ALLOW when the subjectAltName IP bytes equal the peer's IPv4/IPv6 address.
+AccessManager::Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa,
+                                                          const char* data,
+                                                          int size) throw() {
+  if (data == NULL) {
+    return SKIP;
+  }
+  bool match = false;
+  if (sa.ss_family == AF_INET &&
+      size == static_cast<int>(sizeof(in_addr))) {
+    const sockaddr_in* v4 = reinterpret_cast<const sockaddr_in*>(&sa);
+    match = (memcmp(&v4->sin_addr, data, size) == 0);
+  } else if (sa.ss_family == AF_INET6 &&
+             size == static_cast<int>(sizeof(in6_addr))) {
+    const sockaddr_in6* v6 = reinterpret_cast<const sockaddr_in6*>(&sa);
+    match = (memcmp(&v6->sin6_addr, data, size) == 0);
+  }
+  return (match ? ALLOW : SKIP);
+}
+
+// ASCII-only upper-casing. This avoids toupper(), whose result is
+// locale-dependent: under "tr_TR", toupper('i') != toupper('I').
+static char uppercase(char c) {
+  const bool lower = (c >= 'a' && c <= 'z');
+  return lower ? static_cast<char>(c - 'a' + 'A') : c;
+}
+
+/**
+ * Match a host name against a certificate name pattern, ignoring ASCII case.
+ * A '*' in the pattern consumes host characters up to (but not including)
+ * the next '.' or the end of the host, so it covers at most one label.
+ *
+ * @param  host    NUL-terminated host name, typically the remote host
+ * @param  pattern Name taken from the certificate (not necessarily
+ *                 NUL-terminated)
+ * @param  size    Number of characters in "pattern"
+ * @return True if "host" and all "size" characters of "pattern" are
+ *         consumed together, false otherwise.
+ */
+static bool matchName(const char* host, const char* pattern, int size) {
+  const char* h = host;
+  int p = 0;
+  while (p < size && *h != '\0') {
+    if (uppercase(pattern[p]) == uppercase(*h)) {
+      ++p;
+      ++h;
+    } else if (pattern[p] == '*') {
+      while (*h != '.' && *h != '\0') {
+        ++h;
+      }
+      ++p;
+    } else {
+      break;
+    }
+  }
+  return p == size && *h == '\0';
+}
+
+}}}

Added: thrift/trunk/lib/cpp/src/transport/TSSLSocket.h
URL: http://svn.apache.org/viewvc/thrift/trunk/lib/cpp/src/transport/TSSLSocket.h?rev=1073441&view=auto
==============================================================================
--- thrift/trunk/lib/cpp/src/transport/TSSLSocket.h (added)
+++ thrift/trunk/lib/cpp/src/transport/TSSLSocket.h Tue Feb 22 18:12:06 2011
@@ -0,0 +1,304 @@
+// Copyright (c) 2009- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+
+#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
+#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
+
+#include <string>
+#include <boost/shared_ptr.hpp>
+#include <openssl/ssl.h>
+#include "concurrency/Mutex.h"
+#include "TSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+// Forward declarations; both classes are defined later in this header.
+class AccessManager;
+class SSLContext;
+
+/**
+ * OpenSSL implementation for SSL socket interface.
+ *
+ * A TSocket whose traffic goes through an OpenSSL SSL object (ssl_). The
+ * constructors are protected and TSSLSocketFactory is a friend, so instances
+ * are meant to be obtained from a TSSLSocketFactory rather than constructed
+ * directly.
+ *
+ * @author Ping Li <pi...@facebook.com>
+ */
+class TSSLSocket: public TSocket {
+ public:
+ ~TSSLSocket();
+  /**
+   * TTransport interface.
+   *
+   * These replace the TSocket implementations (TSocket declares isOpen,
+   * peek, read and write virtual for this purpose).
+   */
+  bool     isOpen();
+  bool     peek();
+  void     open();
+  void     close();
+  uint32_t read(uint8_t* buf, uint32_t len);
+  void     write(const uint8_t* buf, uint32_t len);
+  void     flush();
+  /**
+   * Set whether to use client or server side SSL handshake protocol.
+   *
+   * @param flag  Use server side handshake protocol if true.
+   */
+  void server(bool flag) { server_ = flag; }
+  /**
+   * Determine whether the SSL socket is server or client mode.
+   *
+   * @return true if server mode, false if client mode
+   */
+  bool server() const { return server_; }
+  /**
+   * Set AccessManager.
+   *
+   * @param manager  Instance of AccessManager; replaces any previously set
+   *                 manager
+   */
+  virtual void access(boost::shared_ptr<AccessManager> manager) {
+    access_ = manager;
+  }
+protected:
+  /**
+   * Constructor.
+   *
+   * @param ctx SSL context for this socket
+   */
+  TSSLSocket(boost::shared_ptr<SSLContext> ctx);
+  /**
+   * Constructor, create an instance of TSSLSocket given an existing socket.
+   *
+   * @param ctx    SSL context for this socket
+   * @param socket An existing socket
+   */
+  TSSLSocket(boost::shared_ptr<SSLContext> ctx, int socket);
+  /**
+   * Constructor.
+   *
+   * @param ctx   SSL context for this socket
+   * @param host  Remote host name
+   * @param port  Remote port number
+   */
+  TSSLSocket(boost::shared_ptr<SSLContext> ctx,
+                               std::string host,
+                                       int port);
+  /**
+   * Authorize peer access after SSL handshake completes.
+   *
+   * Virtual so subclasses can customize the check.
+   * NOTE(review): presumably consults access_ (see AccessManager); confirm
+   * in the implementation.
+   */
+  virtual void authorize();
+  /**
+   * Initiate SSL handshake if not already initiated.
+   */
+  void checkHandshake();
+
+  // true: use the server side of the SSL handshake (see server()).
+  bool server_;
+  // OpenSSL connection object for this socket.
+  SSL* ssl_;
+  // SSL context this socket was constructed with.
+  boost::shared_ptr<SSLContext> ctx_;
+  // Peer access policy, set via access().
+  boost::shared_ptr<AccessManager> access_;
+  // Lets TSSLSocketFactory use the protected constructors.
+  friend class TSSLSocketFactory;
+};
+
+/**
+ * SSL socket factory. SSL sockets should be created via SSL factory.
+ *
+ * Holds the SSL context (ctx_), the server/client mode flag and an
+ * AccessManager. Per the AccessManager documentation, the manager is passed
+ * on to every TSSLSocket this factory creates. TSSLSocket's constructors are
+ * protected and this class is its friend, so this is the intended way to
+ * obtain a TSSLSocket.
+ */
+class TSSLSocketFactory {
+ public:
+  /**
+   * Constructor/Destructor
+   *
+   * NOTE(review): the static initializeOpenSSL()/cleanupOpenSSL() helpers,
+   * together with the static count_ and mutex_ members, suggest that the
+   * constructor and destructor reference-count global OpenSSL setup. If so,
+   * the implicitly generated copy constructor would not increment count_
+   * while the extra destructor call would decrement it; confirm, and
+   * consider making this class non-copyable.
+   */
+  TSSLSocketFactory();
+  virtual ~TSSLSocketFactory();
+  /**
+   * Create an instance of TSSLSocket with a fresh new socket.
+   */
+  virtual boost::shared_ptr<TSSLSocket> createSocket();
+  /**
+   * Create an instance of TSSLSocket with the given socket.
+   *
+   * @param socket An existing socket.
+   */
+  virtual boost::shared_ptr<TSSLSocket> createSocket(int socket);
+  /**
+   * Create an instance of TSSLSocket.
+   *
+   * @param host  Remote host to be connected to
+   * @param port  Remote port to be connected to
+   */
+  virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host,
+                                                     int port);
+  /**
+   * Set ciphers to be used in SSL handshake process.
+   *
+   * @param enable  A list of ciphers, e.g. "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"
+   *                (presumably in OpenSSL cipher-list syntax; confirm in the
+   *                implementation)
+   */
+  virtual void ciphers(const std::string& enable);
+  /**
+   * Enable/Disable authentication.
+   *
+   * @param required Require peer to present valid certificate if true
+   */
+  virtual void authenticate(bool required);
+  /**
+   * Load server certificate.
+   *
+   * @param path   Path to the certificate file
+   * @param format Certificate file format (default "PEM")
+   */
+  virtual void loadCertificate(const char* path, const char* format = "PEM");
+  /**
+   * Load private key.
+   *
+   * @param path   Path to the private key file
+   * @param format Private key file format (default "PEM")
+   */
+  virtual void loadPrivateKey(const char* path, const char* format = "PEM");
+  /**
+   * Load trusted certificates from specified file.
+   *
+   * @param path Path to trusted certificate file
+   */
+  virtual void loadTrustedCertificates(const char* path);
+  /**
+   * Default randomize method.
+   *
+   * NOTE(review): presumably seeds OpenSSL's random number generator;
+   * confirm against the implementation.
+   */
+  virtual void randomize();
+  /**
+   * Override default OpenSSL password callback with getPassword().
+   */
+  void overrideDefaultPasswordCallback();
+  /**
+   * Set/Unset server mode.
+   *
+   * @param flag  Server mode if true
+   */
+  virtual void server(bool flag) { server_ = flag; }
+  /**
+   * Determine whether the socket is in server or client mode.
+   *
+   * @return true, if server mode, or, false, if client mode
+   */
+  virtual bool server() const { return server_; }
+  /**
+   * Set AccessManager.
+   *
+   * @param manager  The AccessManager instance; replaces any previously set
+   *                 manager
+   */
+  virtual void access(boost::shared_ptr<AccessManager> manager) {
+    access_ = manager;
+  }
+ protected:
+  // SSL context owned by this factory (shared via boost::shared_ptr).
+  boost::shared_ptr<SSLContext> ctx_;
+
+  // Process-wide (static) OpenSSL setup and teardown; see the NOTE on the
+  // constructor above.
+  static void initializeOpenSSL();
+  static void cleanupOpenSSL();
+  /**
+   * Override this method for custom password callback. It may be called
+   * multiple times at any time during a session as necessary.
+   *
+   * The default implementation does nothing and leaves "password" unchanged.
+   *
+   * @param password Pass collected password to OpenSSL
+   * @param size     Maximum length of password including the NUL terminator
+   */
+  virtual void getPassword(std::string& password, int size) { }
+ private:
+  // Value reported by server().
+  bool server_;
+  // Manager set via access().
+  boost::shared_ptr<AccessManager> access_;
+  static bool initialized;
+  static concurrency::Mutex mutex_;
+  static uint64_t count_;
+  // NOTE(review): presumably applies this factory's settings to a newly
+  // created socket; confirm in the implementation.
+  void setup(boost::shared_ptr<TSSLSocket> ssl);
+  // Has the shape of OpenSSL's pem_password_cb (buf, size, rwflag, userdata);
+  // presumably the bridge to getPassword(). Confirm in the implementation.
+  static int passwordCallback(char* password, int size, int, void* data);
+};
+
+/**
+ * Exception raised for SSL failures. It is a TTransportException whose type
+ * is always INTERNAL_ERROR, so handlers for transport errors also catch it.
+ */
+class TSSLException: public TTransportException {
+ public:
+  TSSLException(const std::string& message):
+    TTransportException(TTransportException::INTERNAL_ERROR, message) {}
+
+  /**
+   * @return the stored message (message_), or "TSSLException" when it is
+   *         empty
+   */
+  virtual const char* what() const throw() {
+    return message_.empty() ? "TSSLException" : message_.c_str();
+  }
+};
+
+/**
+ * Wrap OpenSSL SSL_CTX into a class.
+ *
+ * The wrapper holds a raw SSL_CTX pointer and declares its own destructor.
+ * A compiler-generated copy would therefore leave two objects holding the
+ * same SSL_CTX, and the context would be freed twice if the destructor
+ * releases it. Copying is disabled. Share one instance via
+ * boost::shared_ptr<SSLContext>, as TSSLSocket and TSSLSocketFactory do.
+ */
+class SSLContext {
+ public:
+  SSLContext();
+  virtual ~SSLContext();
+  /**
+   * Create an SSL object for one connection from this context.
+   * NOTE(review): presumably via SSL_new(); confirm who frees the result.
+   */
+  SSL* createSSL();
+  /**
+   * Return the wrapped SSL_CTX pointer.
+   */
+  SSL_CTX* get() { return ctx_; }
+ private:
+  SSL_CTX* ctx_;
+
+  // Non-copyable (C++03 idiom): declared private and intentionally left
+  // undefined, so any copy attempt fails to compile (or, inside the class,
+  // to link).
+  SSLContext(const SSLContext&);
+  SSLContext& operator=(const SSLContext&);
+};
+
+/**
+ * Callback interface for access control. It's meant to verify the remote host.
+ * It's constructed when application starts and set to TSSLSocketFactory
+ * instance. It's passed onto all TSSLSocket instances created by this factory
+ * object.
+ *
+ * Every hook returns a Decision. The default implementations here always
+ * return DENY, so a subclass must override the hooks it wants to use.
+ */
+class AccessManager {
+ public:
+  // No trailing comma after the last enumerator: that is only valid from
+  // C++11 on.
+  enum Decision {
+    DENY   = -1,    // deny access
+    SKIP   =  0,    // cannot make decision, move on to next (if any)
+    ALLOW  =  1     // allow access
+  };
+  /**
+   * Destructor
+   */
+  virtual ~AccessManager() {}
+  /**
+   * Determine whether the peer should be granted access or not. It's called
+   * once after the SSL handshake completes successfully, before peer
+   * certificate is examined.
+   *
+   * If a valid decision (ALLOW or DENY) is returned, the peer certificate is
+   * not to be verified.
+   *
+   * @param  sa Peer IP address
+   * @return ALLOW or DENY to decide immediately, SKIP to go on to the
+   *         certificate checks. The default implementation returns DENY.
+   */
+  virtual Decision verify(const sockaddr_storage& sa) throw() { return DENY; }
+  /**
+   * Determine whether the peer should be granted access or not. It's called
+   * every time a DNS subjectAltName/common name is extracted from peer's
+   * certificate.
+   *
+   * @param  host Client mode: host name returned by TSocket::getHost()
+   *              Server mode: host name returned by TSocket::getPeerHost()
+   * @param  name SubjectAltName or common name extracted from peer certificate
+   * @param  size Length of name
+   * @return ALLOW or DENY to decide, SKIP to move on to the next name (if
+   *         any). The default implementation returns DENY.
+   *
+   * Note: The "name" parameter may be UTF8 encoded.
+   */
+  virtual Decision verify(const std::string& host, const char* name, int size)
+    throw() { return DENY; }
+  /**
+   * Determine whether the peer should be granted access or not. It's called
+   * every time an IP subjectAltName is extracted from peer's certificate.
+   *
+   * @param  sa   Peer IP address retrieved from the underlying socket
+   * @param  data IP address extracted from certificate
+   * @param  size Length of the IP address
+   * @return ALLOW or DENY to decide, SKIP to move on to the next name (if
+   *         any). The default implementation returns DENY.
+   */
+  virtual Decision verify(const sockaddr_storage& sa, const char* data, int size)
+    throw() { return DENY; }
+};
+
+// Namespace-level alias so implementations can spell the return type as
+// plain "Decision" (as DefaultClientAccessManager does).
+typedef AccessManager::Decision Decision;
+
+/**
+ * Stock AccessManager implementation; the name indicates it is intended for
+ * client-side sockets. Overrides all three AccessManager hooks.
+ */
+class DefaultClientAccessManager: public AccessManager {
+ public:
+  // ---- AccessManager interface ----
+
+  /**
+   * Decide based on the peer's IP address alone, before its certificate is
+   * examined.
+   */
+  Decision verify(const sockaddr_storage& sa) throw();
+  /**
+   * Decide on an IP subjectAltName from the peer certificate: ALLOW when it
+   * is byte-identical to the peer address of the same family, SKIP otherwise.
+   */
+  Decision verify(const sockaddr_storage& sa, const char* data, int size) throw();
+  /**
+   * Decide on a DNS subjectAltName / common name from the peer certificate.
+   */
+  Decision verify(const std::string& host, const char* name, int size) throw();
+};
+
+
+}}}
+
+#endif

Modified: thrift/trunk/lib/cpp/src/transport/TServerSocket.cpp
URL: http://svn.apache.org/viewvc/thrift/trunk/lib/cpp/src/transport/TServerSocket.cpp?rev=1073441&r1=1073440&r2=1073441&view=diff
==============================================================================
--- thrift/trunk/lib/cpp/src/transport/TServerSocket.cpp (original)
+++ thrift/trunk/lib/cpp/src/transport/TServerSocket.cpp Tue Feb 22 18:12:06 2011
@@ -386,7 +386,7 @@ shared_ptr<TTransport> TServerSocket::ac
     throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_SETFL)", errno_copy);
   }
 
-  shared_ptr<TSocket> client(new TSocket(clientSocket));
+  shared_ptr<TSocket> client = createSocket(clientSocket);
   if (sendTimeout_ > 0) {
     client->setSendTimeout(sendTimeout_);
   }
@@ -398,6 +398,10 @@ shared_ptr<TTransport> TServerSocket::ac
   return client;
 }
 
+/**
+ * Wrap an accepted client socket descriptor in a plain TSocket. Declared
+ * virtual in the header so subclasses can produce a different TSocket type.
+ */
+shared_ptr<TSocket> TServerSocket::createSocket(int clientSocket) {
+  shared_ptr<TSocket> wrapped(new TSocket(clientSocket));
+  return wrapped;
+}
+
 void TServerSocket::interrupt() {
   if (intSock1_ >= 0) {
     int8_t byte = 0;

Modified: thrift/trunk/lib/cpp/src/transport/TServerSocket.h
URL: http://svn.apache.org/viewvc/thrift/trunk/lib/cpp/src/transport/TServerSocket.h?rev=1073441&r1=1073440&r2=1073441&view=diff
==============================================================================
--- thrift/trunk/lib/cpp/src/transport/TServerSocket.h (original)
+++ thrift/trunk/lib/cpp/src/transport/TServerSocket.h Tue Feb 22 18:12:06 2011
@@ -56,6 +56,7 @@ class TServerSocket : public TServerTran
 
  protected:
   boost::shared_ptr<TTransport> acceptImpl();
+  /**
+   * Wrap a newly accepted client socket descriptor in a TSocket. Virtual so
+   * that subclasses can return a different TSocket type.
+   *
+   * @param client Descriptor of the accepted connection
+   */
+  virtual boost::shared_ptr<TSocket> createSocket(int client);
 
  private:
   int port_;

Modified: thrift/trunk/lib/cpp/src/transport/TSocket.h
URL: http://svn.apache.org/viewvc/thrift/trunk/lib/cpp/src/transport/TSocket.h?rev=1073441&r1=1073440&r2=1073441&view=diff
==============================================================================
--- thrift/trunk/lib/cpp/src/transport/TSocket.h (original)
+++ thrift/trunk/lib/cpp/src/transport/TSocket.h Tue Feb 22 18:12:06 2011
@@ -70,12 +70,12 @@ class TSocket : public TVirtualTransport
    *
    * @return Is the socket alive?
    */
-  bool isOpen();
+  virtual bool isOpen();  // virtual: overridden by TSSLSocket
 
   /**
    * Calls select on the socket to see if there is more data available.
    */
-  bool peek();
+  virtual bool peek();  // virtual: overridden by TSSLSocket
 
   /**
    * Creates and opens the UNIX socket.
@@ -92,12 +92,12 @@ class TSocket : public TVirtualTransport
   /**
    * Reads from the underlying socket.
    */
-  uint32_t read(uint8_t* buf, uint32_t len);
+  virtual uint32_t read(uint8_t* buf, uint32_t len);  // virtual: overridden by TSSLSocket
 
   /**
    * Writes to the underlying socket.  Loops until done or fail.
    */
-  void write(const uint8_t* buf, uint32_t len);
+  virtual void write(const uint8_t* buf, uint32_t len);  // virtual: overridden by TSSLSocket
 
   /**
    * Writes to the underlying socket.  Does single send() and returns result.

Modified: thrift/trunk/test/cpp/src/TestClient.cpp
URL: http://svn.apache.org/viewvc/thrift/trunk/test/cpp/src/TestClient.cpp?rev=1073441&r1=1073440&r2=1073441&view=diff
==============================================================================
--- thrift/trunk/test/cpp/src/TestClient.cpp (original)
+++ thrift/trunk/test/cpp/src/TestClient.cpp Tue Feb 22 18:12:06 2011
@@ -23,6 +23,7 @@
 #include <protocol/TBinaryProtocol.h>
 #include <transport/TTransportUtils.h>
 #include <transport/TSocket.h>
+#include <transport/TSSLSocket.h>
 
 #include <boost/shared_ptr.hpp>
 #include "ThriftTest.h"
@@ -56,6 +57,7 @@ int main(int argc, char** argv) {
   int port = 9090;
   int numTests = 1;
   bool framed = false;
+  bool ssl = false;
 
   for (int i = 0; i < argc; ++i) {
     if (strcmp(argv[i], "-h") == 0) {
@@ -71,9 +73,22 @@ int main(int argc, char** argv) {
       numTests = atoi(argv[++i]);
     } else if (strcmp(argv[i], "-f") == 0) {
       framed = true;
+    } else if (strcmp(argv[i], "--ssl") == 0) {
+      ssl = true;
     }
   }
 
+  shared_ptr<TSocket> socket;
+  shared_ptr<TSSLSocketFactory> factory;
+  if (ssl) {
+    factory = shared_ptr<TSSLSocketFactory>(new TSSLSocketFactory());
+    factory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+    factory->loadTrustedCertificates("./trusted-ca-certificate.pem");
+    factory->authenticate(true);
+    socket = factory->createSocket(host, port);
+  } else {
+    socket = shared_ptr<TSocket>(new TSocket(host, port));
+  }
 
   shared_ptr<TBufferBase> transport;
 

Modified: thrift/trunk/test/cpp/src/TestServer.cpp
URL: http://svn.apache.org/viewvc/thrift/trunk/test/cpp/src/TestServer.cpp?rev=1073441&r1=1073440&r2=1073441&view=diff
==============================================================================
--- thrift/trunk/test/cpp/src/TestServer.cpp (original)
+++ thrift/trunk/test/cpp/src/TestServer.cpp Tue Feb 22 18:12:06 2011
@@ -34,6 +34,7 @@
 
 #define __STDC_FORMAT_MACROS
 #include <inttypes.h>
+#include <signal.h>
 
 using namespace std;
 using namespace boost;
@@ -326,6 +327,7 @@ int main(int argc, char **argv) {
   string serverType = "simple";
   string protocolType = "binary";
   size_t workerCount = 4;
+  bool ssl = false;
 
   ostringstream usage;
 
@@ -391,6 +393,11 @@ int main(int argc, char **argv) {
     cerr << usage;
   }
 
+  if (args["ssl"] == "true") {
+    ssl = true;
+    signal(SIGPIPE, SIG_IGN);
+  }
+
   // Dispatcher
   shared_ptr<TProtocolFactory> protocolFactory(
       new TBinaryProtocolFactoryT<TBufferBase>());
@@ -407,8 +414,18 @@ int main(int argc, char **argv) {
   }
 
   // Transport
-  shared_ptr<TServerSocket> serverSocket(new TServerSocket(port));
+  shared_ptr<TSSLSocketFactory> sslSocketFactory;
+  shared_ptr<TServerSocket> serverSocket;
 
+  if (ssl) {
+    sslSocketFactory = shared_ptr<TSSLSocketFactory>(new TSSLSocketFactory());
+    sslSocketFactory->loadCertificate("./server-certificate.pem");
+    sslSocketFactory->loadPrivateKey("./server-private-key.pem");
+    sslSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+    serverSocket = shared_ptr<TServerSocket>(new TSSLServerSocket(port, sslSocketFactory));
+  } else {
+    serverSocket = shared_ptr<TServerSocket>(new TServerSocket(port));
+  }
   // Factory
   shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());