You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mesos.apache.org by be...@apache.org on 2019/07/05 12:08:37 UTC

[mesos] 08/12: Added ability to pass custom SSL context to `Socket::connect()`.

This is an automated email from the ASF dual-hosted git repository.

bennoe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mesos.git

commit ec129665a346f86c738522536f89de7c519f3e0d
Author: Benno Evers <be...@mesosphere.com>
AuthorDate: Fri Jun 28 20:12:44 2019 +0200

    Added ability to pass custom SSL context to `Socket::connect()`.
    
    Users of libprocess can now pass a custom SSL context when
    connecting a generic socket via the `Socket::connect()`
    function.
    
    Additionally the API of `Socket::connect()` was also reworked
    according to the following boundary conditions requested by
    libprocess maintainers:
    
     * When libprocess is compiled without SSL support, neither the
       declaration of the TLS configuration object nor the `connect()`
       overload that accepts the TLS configuration should be available.
     * Passing just the servername is not an acceptable shorthand for
       using the default TLS configuration together with that servername.
     * When the incorrect overload is selected (i.e. passing TLS config
       to a poll socket or omitting TLS configuration for a TLS socket),
       the program should abort.
    
    The following changes are introduced according to the requirements
    above:
    
     * A new class `openssl::TLSClientConfig` is introduced when libprocess
       is compiled with SSL support.
     * A new overload
       `Socket::connect(const Address&, const TLSClientConfig&)` is
       introduced when libprocess is compiled with SSL support.
     * All call sites are adjusted to check the socket kind before calling
       `connect()`.
    
    Review: https://reviews.apache.org/r/70991
---
 3rdparty/libprocess/include/Makefile.am            |   1 +
 3rdparty/libprocess/include/process/socket.hpp     |  21 ++-
 .../libprocess/include/process/ssl/tls_config.hpp  |  89 +++++++++++
 3rdparty/libprocess/src/http.cpp                   |  18 ++-
 3rdparty/libprocess/src/openssl.cpp                |  51 +++++-
 3rdparty/libprocess/src/openssl.hpp                |   2 +
 3rdparty/libprocess/src/poll_socket.hpp            |   8 +-
 .../src/posix/libevent/libevent_ssl_socket.cpp     |  68 ++++----
 .../src/posix/libevent/libevent_ssl_socket.hpp     |   4 +-
 3rdparty/libprocess/src/posix/poll_socket.cpp      |  11 +-
 3rdparty/libprocess/src/process.cpp                |  27 +++-
 3rdparty/libprocess/src/tests/http_tests.cpp       |  16 +-
 3rdparty/libprocess/src/tests/socket_tests.cpp     |  34 +++-
 3rdparty/libprocess/src/tests/ssl_client.cpp       |  14 +-
 3rdparty/libprocess/src/tests/ssl_tests.cpp        | 172 ++++++++++++++++++++-
 15 files changed, 483 insertions(+), 53 deletions(-)

diff --git a/3rdparty/libprocess/include/Makefile.am b/3rdparty/libprocess/include/Makefile.am
index 1ddcc2d..e1a6f1e 100644
--- a/3rdparty/libprocess/include/Makefile.am
+++ b/3rdparty/libprocess/include/Makefile.am
@@ -71,6 +71,7 @@ nobase_include_HEADERS =		\
   process/subprocess.hpp		\
   process/ssl/flags.hpp			\
   process/ssl/gtest.hpp			\
+  process/ssl/tls_config.hpp		\
   process/ssl/utilities.hpp		\
   process/time.hpp			\
   process/timeout.hpp			\
diff --git a/3rdparty/libprocess/include/process/socket.hpp b/3rdparty/libprocess/include/process/socket.hpp
index 88f6486..48860f8 100644
--- a/3rdparty/libprocess/include/process/socket.hpp
+++ b/3rdparty/libprocess/include/process/socket.hpp
@@ -23,6 +23,8 @@
 #include <process/address.hpp>
 #include <process/future.hpp>
 
+#include <process/ssl/tls_config.hpp>
+
 #include <stout/abort.hpp>
 #include <stout/error.hpp>
 #include <stout/nothing.hpp>
@@ -150,8 +152,13 @@ public:
   virtual Future<std::shared_ptr<SocketImpl>> accept() = 0;
 
   virtual Future<Nothing> connect(
+      const Address& address) = 0;
+
+#ifdef USE_SSL_SOCKET
+  virtual Future<Nothing> connect(
       const Address& address,
-      const Option<std::string>& peer_hostname) = 0;
+      const openssl::TLSClientConfig& config) = 0;
+#endif
 
   virtual Future<size_t> recv(char* data, size_t size) = 0;
   virtual Future<size_t> send(const char* data, size_t size) = 0;
@@ -361,17 +368,23 @@ public:
       });
   }
 
+  // NOTE: Calling this overload when `kind() == SSL` will result
+  // in program termination.
   Future<Nothing> connect(const AddressType& address)
   {
-    return impl->connect(address, None());
+    return impl->connect(address);
   }
 
+#ifdef USE_SSL_SOCKET
+  // NOTE: Calling this overload when `kind() == POLL` will result
+  // in program termination.
   Future<Nothing> connect(
       const AddressType& address,
-      const Option<std::string>& peer_hostname)
+      const openssl::TLSClientConfig& config)
   {
-    return impl->connect(address, peer_hostname);
+    return impl->connect(address, config);
   }
+#endif
 
   Future<size_t> recv(char* data, size_t size) const
   {
diff --git a/3rdparty/libprocess/include/process/ssl/tls_config.hpp b/3rdparty/libprocess/include/process/ssl/tls_config.hpp
new file mode 100644
index 0000000..18c51a8
--- /dev/null
+++ b/3rdparty/libprocess/include/process/ssl/tls_config.hpp
@@ -0,0 +1,101 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License
+
+#ifndef __PROCESS_SSL_TLS_CONFIG_HPP__
+#define __PROCESS_SSL_TLS_CONFIG_HPP__
+
+#ifdef USE_SSL_SOCKET
+
+#include <openssl/ssl.h>
+
+#include <string>
+
+#include <process/address.hpp>
+
+#include <stout/ip.hpp>
+#include <stout/nothing.hpp>
+#include <stout/option.hpp>
+#include <stout/try.hpp>
+
+namespace process {
+namespace network {
+namespace openssl {
+
+// Client-side TLS settings for a single connection, passed to the
+// `Socket::connect(address, config)` overload.
+struct TLSClientConfig
+{
+  // Callback that will be called before the TLS handshake is started.
+  typedef Try<Nothing> (*ConfigureSocketCallback)(
+      SSL* ssl,
+      const Address& peer,
+      const Option<std::string>& servername);
+
+  // Callback that will be called after the TLS handshake has been
+  // completed successfully.
+  typedef Try<Nothing> (*VerifyCallback)(
+      const SSL* const ssl,
+      const Option<std::string>& servername,
+      const Option<net::IP>& ip);
+
+  // The `configure_socket` and `verify` callbacks can be set to
+  // nullptr, in which case they will not be called.
+  TLSClientConfig(
+      const Option<std::string>& servername,
+      SSL_CTX* ctx,
+      ConfigureSocketCallback configure_socket,
+      VerifyCallback verify);
+
+  // Context from which the `SSL` object for this connection is created.
+  // Must not be null (`connect()` fails otherwise). The pointer is
+  // copied as-is; this struct does not take ownership of it.
+  SSL_CTX* ctx;
+
+  // Server hostname to be used for hostname validation, if any.
+  // This will be passed as the `servername` argument to both
+  // callbacks.
+  //
+  // TODO(bevers): Use this for SNI as well when the linked OpenSSL
+  // supports it.
+  Option<std::string> servername;
+
+  // User-specified callbacks; either may be nullptr.
+  VerifyCallback verify;
+  ConfigureSocketCallback configure_socket;
+};
+
+
+// Returns a `TLSClientConfig` object that is configured with the
+// provided `servername` and the global libprocess SSL context. The
+// callbacks `verify` and `configure_socket` are set up with a pair of
+// default functions that implement the SSL behaviour configured
+// via the `LIBPROCESS_SSL_*` environment variables.
+//
+// NOTE: Callers must _NOT_ modify the `ctx` in the returned `TLSClientConfig`.
+// Doing so would mutate global libprocess state.
+//
+// NOTE: The returned `ctx`, `verify` and `configure_socket` values all
+// implement parts of the libprocess default behaviour and rely on each other
+// for working correctly. It is not recommended to change one of them while
+// keeping the others, unless you know *exactly* what you're doing.
+//
+// NOTE: The passed `servername` will be ignored and a reverse DNS lookup will
+// be done instead if `LIBPROCESS_SSL_HOSTNAME_VALIDATION_SCHEME=legacy`.
+TLSClientConfig create_tls_client_config(const Option<std::string>& servername);
+
+} // namespace openssl {
+} // namespace network {
+} // namespace process {
+
+#endif // USE_SSL_SOCKET
+
+#endif // __PROCESS_SSL_TLS_CONFIG_HPP__
diff --git a/3rdparty/libprocess/src/http.cpp b/3rdparty/libprocess/src/http.cpp
index 0ed9aa8..0ed41aa 100644
--- a/3rdparty/libprocess/src/http.cpp
+++ b/3rdparty/libprocess/src/http.cpp
@@ -46,6 +46,8 @@
 #include <process/socket.hpp>
 #include <process/state_machine.hpp>
 
+#include <process/ssl/tls_config.hpp>
+
 #include <stout/error.hpp>
 #include <stout/foreach.hpp>
 #include <stout/ip.hpp>
@@ -1449,7 +1451,21 @@ Future<Connection> connect(
     return Failure("Failed to create socket: " + socket.error());
   }
 
-  return socket->connect(address, peer_hostname)
+  Future<Nothing> connected = [&]() {
+    switch (scheme) {
+      case Scheme::HTTP:
+        return socket->connect(address);
+#ifdef USE_SSL_SOCKET
+      case Scheme::HTTPS:
+        return socket->connect(
+            address,
+            network::openssl::create_tls_client_config(peer_hostname));
+#endif
+    }
+    UNREACHABLE();
+  }();
+
+  return connected
     .then([socket, address]() -> Future<Connection> {
       Try<network::Address> localAddress = socket->address();
       if (localAddress.isError()) {
diff --git a/3rdparty/libprocess/src/openssl.cpp b/3rdparty/libprocess/src/openssl.cpp
index fb03032..850b6f7 100644
--- a/3rdparty/libprocess/src/openssl.cpp
+++ b/3rdparty/libprocess/src/openssl.cpp
@@ -29,6 +29,7 @@
 #include <process/once.hpp>
 
 #include <process/ssl/flags.hpp>
+#include <process/ssl/tls_config.hpp>
 
 #include <stout/os.hpp>
 #include <stout/strings.hpp>
@@ -802,10 +803,11 @@ Try<Nothing> verify(
     return Try<Nothing>(Nothing());
   }
 
-  // For backwards compatibility, the 'libprocess' scheme will attempt to get
-  // the peer hostname using a reverse DNS lookup if connecting via IP address.
+  // NOTE: For backwards compatibility, the 'legacy' hostname validation
+  // scheme ignores the passed hostname whenever the peer IP is known and
+  // instead attempts to get the peer hostname using a reverse DNS lookup.
   Option<std::string> peer_hostname = hostname;
-  if (!hostname.isSome() && ip.isSome()) {
+  if (ip.isSome()) {
     VLOG(1) << "Doing rDNS lookup for 'libprocess' hostname validation";
     Stopwatch watch;
 
@@ -1042,6 +1044,49 @@ Try<Nothing> configure_socket(
   return Nothing();
 }
 
+
+// Adapters that fix `Mode::CLIENT` so that `verify()` and `configure_socket()`
+// above can be used as `TLSClientConfig` callbacks.
+Try<Nothing> client_verify(
+    const SSL* const ssl,
+    const Option<std::string>& hostname,
+    const Option<net::IP>& ip)
+{
+  return verify(ssl, Mode::CLIENT, hostname, ip);
+}
+
+
+Try<Nothing> client_configure_socket(
+    SSL* ssl,
+    const Address& peer,
+    const Option<std::string>& peer_hostname)
+{
+  return configure_socket(ssl, Mode::CLIENT, peer, peer_hostname);
+}
+
+
+TLSClientConfig::TLSClientConfig(
+    const Option<std::string>& servername,
+    SSL_CTX *ctx,
+    ConfigureSocketCallback configure_socket,
+    VerifyCallback verify)
+  : ctx(ctx),
+    servername(servername),
+    verify(verify),
+    configure_socket(configure_socket)
+{}
+
+
+TLSClientConfig create_tls_client_config(
+    const Option<std::string>& servername)
+{
+  return TLSClientConfig(
+      servername,
+      openssl::ctx,
+      &client_configure_socket,
+      &client_verify);
+}
+
 } // namespace openssl {
 } // namespace network {
 } // namespace process {
diff --git a/3rdparty/libprocess/src/openssl.hpp b/3rdparty/libprocess/src/openssl.hpp
index d4ddbff..271cc95 100644
--- a/3rdparty/libprocess/src/openssl.hpp
+++ b/3rdparty/libprocess/src/openssl.hpp
@@ -30,6 +30,8 @@
 
 #include <process/network.hpp>
 
+#include <process/ssl/tls_config.hpp>
+
 namespace process {
 namespace network {
 namespace openssl {
diff --git a/3rdparty/libprocess/src/poll_socket.hpp b/3rdparty/libprocess/src/poll_socket.hpp
index c60e454..881dab1 100644
--- a/3rdparty/libprocess/src/poll_socket.hpp
+++ b/3rdparty/libprocess/src/poll_socket.hpp
@@ -14,6 +14,8 @@
 
 #include <process/socket.hpp>
 
+#include <process/ssl/tls_config.hpp>
+
 #include <stout/try.hpp>
 
 namespace process {
@@ -33,8 +35,12 @@ public:
   Try<Nothing> listen(int backlog) override;
   Future<std::shared_ptr<SocketImpl>> accept() override;
   Future<Nothing> connect(
+      const Address& address) override;
+#ifdef USE_SSL_SOCKET
+  Future<Nothing> connect(
       const Address& address,
-      const Option<std::string>& peer_hostname) override;
+      const openssl::TLSClientConfig& config) override;
+#endif
   Future<size_t> recv(char* data, size_t size) override;
   Future<size_t> send(const char* data, size_t size) override;
   Future<size_t> sendfile(int_fd fd, off_t offset, size_t size) override;
diff --git a/3rdparty/libprocess/src/posix/libevent/libevent_ssl_socket.cpp b/3rdparty/libprocess/src/posix/libevent/libevent_ssl_socket.cpp
index 8f3d8d9..dcb6d8e 100644
--- a/3rdparty/libprocess/src/posix/libevent/libevent_ssl_socket.cpp
+++ b/3rdparty/libprocess/src/posix/libevent/libevent_ssl_socket.cpp
@@ -437,25 +437,23 @@ void LibeventSSLSocketImpl::event_callback(short events)
     // If we're connecting, then we've succeeded. Time to do
     // post-verification.
     CHECK_NOTNULL(bev);
-
-    // Do post-validation of connection.
-    SSL* ssl = bufferevent_openssl_get_ssl(bev);
-
-    // We intentionally don't store the hostname passed to
-    // `connect()`: The 'openssl' hostname validation already
-    // verified the hostname if we get here, and the 'libprocess'
-    // algorithm should always use rDNS lookups on the IP address
-    // for backwards compatibility with the previous behaviour.
-    Try<Nothing> verify = openssl::verify(
-        ssl, Mode::CLIENT, None(), peer_ip);
-
-    if (verify.isError()) {
-      VLOG(1) << "Failed connect, verification error: " << verify.error();
-      SSL_free(ssl);
-      bufferevent_free(bev);
-      bev = nullptr;
-      current_connect_request->promise.fail(verify.error());
-      return;
+    CHECK(client_config.isSome());
+
+    if (client_config->verify) {
+      // Do post-validation of connection.
+      SSL* ssl = bufferevent_openssl_get_ssl(bev);
+
+      Try<Nothing> verify = client_config->verify(
+          ssl, client_config->servername, peer_ip);
+
+      if (verify.isError()) {
+        VLOG(1) << "Failed connect, verification error: " << verify.error();
+        SSL_free(ssl);
+        bufferevent_free(bev);
+        bev = nullptr;
+        current_connect_request->promise.fail(verify.error());
+        return;
+      }
     }
 
     current_connect_request->promise.set(Nothing());
@@ -515,8 +513,15 @@ LibeventSSLSocketImpl::LibeventSSLSocketImpl(
 
 
 Future<Nothing> LibeventSSLSocketImpl::connect(
+    const Address& address)
+{
+  LOG(FATAL) << "No TLS config was passed to a SSL socket.";
+}
+
+
+Future<Nothing> LibeventSSLSocketImpl::connect(
     const Address& address,
-    const Option<string>& peer_hostname)
+    const openssl::TLSClientConfig& config)
 {
   if (bev != nullptr) {
     return Failure("Socket is already connected");
@@ -526,16 +531,26 @@ Future<Nothing> LibeventSSLSocketImpl::connect(
     return Failure("Socket is already connecting");
   }
 
-  SSL* ssl = SSL_new(openssl::context());
+  if (config.ctx == nullptr) {
+    return Failure("Invalid SSL context");
+  }
+
+  SSL* ssl = SSL_new(config.ctx);
   if (ssl == nullptr) {
     return Failure("Failed to connect: SSL_new");
   }
 
-  Try<Nothing> configured = openssl::configure_socket(
-      ssl, openssl::Mode::CLIENT, address, peer_hostname);
+  client_config = config;
 
-  if (configured.isError()) {
-    return Failure("Failed to configure socket: " + configured.error());
+  if (config.configure_socket) {
+    Try<Nothing> configured = config.configure_socket(
+        ssl, address, config.servername);
+
+    if (configured.isError()) {
+      // `ssl` is not yet owned by a bufferevent, so release it here.
+      SSL_free(ssl);
+      return Failure("Failed to configure socket: " + configured.error());
+    }
   }
 
   // Construct the bufferevent in the connecting state.
@@ -565,8 +580,8 @@ Future<Nothing> LibeventSSLSocketImpl::connect(
     peer_ip = inetAddress.ip;
   }
 
-  if (peer_hostname.isSome()) {
-    VLOG(2) << "Connecting to " << peer_hostname.get() << " at " << address;
+  if (config.servername.isSome()) {
+    VLOG(2) << "Connecting to " << config.servername.get() << " at " << address;
   } else {
     VLOG(2) << "Connecting to " << address << " with no hostname specified";
   }
@@ -618,6 +633,7 @@ Future<Nothing> LibeventSSLSocketImpl::connect(
             SSL* ssl = bufferevent_openssl_get_ssl(CHECK_NOTNULL(self->bev));
             SSL_free(ssl);
             bufferevent_free(self->bev);
+
             self->bev = nullptr;
 
             Owned<ConnectRequest> request;
diff --git a/3rdparty/libprocess/src/posix/libevent/libevent_ssl_socket.hpp b/3rdparty/libprocess/src/posix/libevent/libevent_ssl_socket.hpp
index b781d6a..7bcc66f 100644
--- a/3rdparty/libprocess/src/posix/libevent/libevent_ssl_socket.hpp
+++ b/3rdparty/libprocess/src/posix/libevent/libevent_ssl_socket.hpp
@@ -40,9 +40,10 @@ public:
   ~LibeventSSLSocketImpl() override;
 
   // Implement 'SocketImpl' interface.
+  Future<Nothing> connect(const Address& address) override;
   Future<Nothing> connect(
       const Address& address,
-      const Option<std::string>& peer_hostname) override;
+      const openssl::TLSClientConfig& config) override;
 
   Future<size_t> recv(char* data, size_t size) override;
   // Send does not currently support discard. See implementation.
@@ -190,6 +191,7 @@ private:
   Queue<Future<std::shared_ptr<SocketImpl>>> accept_queue;
 
   Option<net::IP> peer_ip;
+  Option<openssl::TLSClientConfig> client_config;
 };
 
 } // namespace internal {
diff --git a/3rdparty/libprocess/src/posix/poll_socket.cpp b/3rdparty/libprocess/src/posix/poll_socket.cpp
index 96c8df6..ecc2bd4 100644
--- a/3rdparty/libprocess/src/posix/poll_socket.cpp
+++ b/3rdparty/libprocess/src/posix/poll_socket.cpp
@@ -114,8 +114,7 @@ Future<std::shared_ptr<SocketImpl>> PollSocketImpl::accept()
 
 
 Future<Nothing> PollSocketImpl::connect(
-    const Address& address,
-    const Option<string>& /* peer_hostname */)
+    const Address& address)
 {
   Try<Nothing, SocketError> connect = network::connect(get(), address);
   if (connect.isError()) {
@@ -160,6 +159,14 @@ Future<Nothing> PollSocketImpl::connect(
   return Nothing();
 }
 
+#ifdef USE_SSL_SOCKET
+Future<Nothing> PollSocketImpl::connect(
+    const Address& address,
+    const openssl::TLSClientConfig& config)
+{
+  LOG(FATAL) << "TLS config was passed to a PollSocket.";
+}
+#endif
 
 Future<size_t> PollSocketImpl::recv(char* data, size_t size)
 {
diff --git a/3rdparty/libprocess/src/process.cpp b/3rdparty/libprocess/src/process.cpp
index d50f88d..1de6da2 100644
--- a/3rdparty/libprocess/src/process.cpp
+++ b/3rdparty/libprocess/src/process.cpp
@@ -150,6 +150,10 @@ using process::network::inet::Socket;
 
 using process::network::internal::SocketImpl;
 
+#ifdef USE_SSL_SOCKET
+using process::network::openssl::create_tls_client_config;
+#endif
+
 using std::deque;
 using std::find;
 using std::list;
@@ -1443,6 +1447,25 @@ void ignore_recv_data(
 // Forward declaration.
 void send(Encoder* encoder, Socket socket);
 
+// Connects `socket` using the `connect()` overload that matches
+// `socket.kind()`; `servername` is only used for SSL sockets.
+Future<Nothing> connectSocket(
+    Socket& socket,
+    const Address& address,
+    const Option<string>& servername)
+{
+  switch (socket.kind()) {
+    case SocketImpl::Kind::POLL:
+      return socket.connect(address);
+#ifdef USE_SSL_SOCKET
+    case SocketImpl::Kind::SSL:
+      return socket.connect(
+          address, create_tls_client_config(servername));
+#endif
+  }
+
+  UNREACHABLE();
+}
 
 } // namespace internal {
 
@@ -1671,7 +1694,7 @@ void SocketManager::link(
 
   if (connect) {
     CHECK_SOME(socket);
-    socket->connect(to.address, to.host)
+    internal::connectSocket(*socket, to.address, to.host)
       .onAny(lambda::bind(
           &SocketManager::link_connect,
           this,
@@ -2033,7 +2056,7 @@ void SocketManager::send(Message&& message, const SocketImpl::Kind& kind)
 
   if (connect) {
     CHECK_SOME(socket);
-    socket->connect(address, message.to.host)
+    internal::connectSocket(*socket, address, message.to.host)
       .onAny(lambda::bind(
             // TODO(benh): with C++14 we can use lambda instead of
             // `std::bind` and capture `message` with a `std::move`.
diff --git a/3rdparty/libprocess/src/tests/http_tests.cpp b/3rdparty/libprocess/src/tests/http_tests.cpp
index 4d37294..8cb5f16 100644
--- a/3rdparty/libprocess/src/tests/http_tests.cpp
+++ b/3rdparty/libprocess/src/tests/http_tests.cpp
@@ -40,6 +40,7 @@
 #include <process/socket.hpp>
 
 #include <process/ssl/gtest.hpp>
+#include <process/ssl/tls_config.hpp>
 
 #include <stout/base64.hpp>
 #include <stout/gtest.hpp>
@@ -260,7 +261,20 @@ TEST_P(HTTPTest, Endpoints)
 
     inet::Socket socket = create.get();
 
-    AWAIT_READY(socket.connect(http.process->self().address));
+    Future<Nothing> connected = [&]() {
+      switch(socket.kind()) {
+        case network::internal::SocketImpl::Kind::POLL:
+          return socket.connect(http.process->self().address);
+#ifdef USE_SSL_SOCKET
+        case network::internal::SocketImpl::Kind::SSL:
+          return socket.connect(
+              http.process->self().address,
+              network::openssl::create_tls_client_config(None()));
+#endif
+      }
+      UNREACHABLE();
+    }();
+    AWAIT_READY(connected);
 
     std::ostringstream out;
     out << "GET /" << http.process->self().id << "/body"
diff --git a/3rdparty/libprocess/src/tests/socket_tests.cpp b/3rdparty/libprocess/src/tests/socket_tests.cpp
index b09ae23..13e757a 100644
--- a/3rdparty/libprocess/src/tests/socket_tests.cpp
+++ b/3rdparty/libprocess/src/tests/socket_tests.cpp
@@ -54,6 +54,29 @@ void reinitialize(
 
 } // namespace process {
 
+
+// Helper function to safely connect a socket using the correct overload
+// of `connect()`.
+template<typename T, typename AddressType>
+static Future<Nothing> connectSocket(
+    process::network::internal::Socket<T>& socket,
+    const AddressType& address)
+{
+  switch (socket.kind()) {
+    case process::network::internal::SocketImpl::Kind::POLL:
+      return socket.connect(address);
+#ifdef USE_SSL_SOCKET
+    case process::network::internal::SocketImpl::Kind::SSL:
+      // The tests below never define an appropriate hostname to use, thus
+      // relying implicitly on the 'legacy' hostname validation scheme.
+      return socket.connect(
+          address,
+          process::network::openssl::create_tls_client_config(None()));
+#endif
+  }
+  UNREACHABLE();
+}
+
 class SocketTest : public TemporaryDirectoryTest {};
 
 #ifndef __WINDOWS__
@@ -76,7 +99,7 @@ TEST_F(SocketTest, Unix)
 
   Future<unix::Socket> accept = server->accept();
 
-  AWAIT_READY(client->connect(address.get()));
+  AWAIT_READY(connectSocket(*client, address.get()));
   AWAIT_READY(accept);
 
   unix::Socket socket = accept.get();
@@ -181,8 +204,9 @@ TEST_P(NetSocketTest, EOFBeforeRecv)
   // invalid address, except when used to resolve a host's address
   // for the first time.
   // See: https://tools.ietf.org/html/rfc1122#section-3.2.1.3
-  AWAIT_READY(
-      client->connect(Address(process::address().ip, server_address->port)));
+  AWAIT_READY(connectSocket(
+      *client,
+      Address(process::address().ip, server_address->port)));
 
   AWAIT_READY(server_accept);
 
@@ -226,8 +250,8 @@ TEST_P(NetSocketTest, EOFAfterRecv)
   // invalid address, except when used to resolve a host's address
   // for the first time.
   // See: https://tools.ietf.org/html/rfc1122#section-3.2.1.3
-  AWAIT_READY(
-      client->connect(Address(process::address().ip, server_address->port)));
+  AWAIT_READY(connectSocket(
+      *client, Address(process::address().ip, server_address->port)));
 
   AWAIT_READY(server_accept);
 
diff --git a/3rdparty/libprocess/src/tests/ssl_client.cpp b/3rdparty/libprocess/src/tests/ssl_client.cpp
index de87b3b..afa0a9c 100644
--- a/3rdparty/libprocess/src/tests/ssl_client.cpp
+++ b/3rdparty/libprocess/src/tests/ssl_client.cpp
@@ -143,8 +143,18 @@ TEST_F(SSLClientTest, client)
   EXPECT_SOME(ip);
 
   // Connect to the server socket located at `ip:port`.
-  const Future<Nothing> connect =
-    socket.connect(Address(ip.get(), flags.port));
+  Address address(ip.get(), flags.port);
+  Future<Nothing> connect = [&]() {
+    switch(socket.kind()) {
+      case SocketImpl::Kind::POLL:
+        return socket.connect(address);
+      case SocketImpl::Kind::SSL:
+        return socket.connect(
+            address,
+            openssl::create_tls_client_config(None()));
+    }
+    UNREACHABLE();
+  }();
 
   // Verify that the client views the connection as established.
   AWAIT_EXPECT_READY(connect);
diff --git a/3rdparty/libprocess/src/tests/ssl_tests.cpp b/3rdparty/libprocess/src/tests/ssl_tests.cpp
index e52451e..e1790d5 100644
--- a/3rdparty/libprocess/src/tests/ssl_tests.cpp
+++ b/3rdparty/libprocess/src/tests/ssl_tests.cpp
@@ -554,7 +554,11 @@ TEST_F(SSLTest, PeerAddress)
   const Try<Address> server_address = server->address();
   ASSERT_SOME(server_address);
 
-  const Future<Nothing> connect = client.connect(server_address.get());
+  // Pass `None()` as hostname because this test is still
+  // using the 'legacy' hostname validation scheme.
+  const Future<Nothing> connect = client.connect(
+      server_address.get(),
+      openssl::create_tls_client_config(None()));
 
   AWAIT_ASSERT_READY(socket);
   AWAIT_ASSERT_READY(connect);
@@ -734,7 +738,12 @@ TEST_F(SSLTest, ShutdownThenSend)
 
   Try<Socket> client = Socket::create(SocketImpl::Kind::SSL);
   ASSERT_SOME(client);
-  AWAIT_ASSERT_READY(client->connect(server->address().get()));
+
+  // Pass `None()` as hostname because this test is still
+  // using the 'legacy' hostname validation scheme.
+  AWAIT_ASSERT_READY(client->connect(
+      server->address().get(),
+      openssl::create_tls_client_config(None())));
 
   AWAIT_ASSERT_READY(socket);
 
@@ -787,7 +796,11 @@ TEST_P(SSLVerifyIPAddTest, BasicSameProcess)
 
   Future<Socket> accept = server->accept();
 
-  AWAIT_ASSERT_READY(client->connect(address.get()));
+  // Pass `None()` as hostname because this test is still
+  // using the 'legacy' hostname validation scheme.
+  AWAIT_ASSERT_READY(client->connect(
+      address.get(),
+      openssl::create_tls_client_config(None())));
 
   // Wait for the server to have accepted the client connection.
   AWAIT_ASSERT_READY(accept);
@@ -841,7 +854,11 @@ TEST_P(SSLVerifyIPAddTest, BasicSameProcessUnix)
 
   Future<unix::Socket> accept = server->accept();
 
-  AWAIT_ASSERT_READY(client->connect(address.get()));
+  // Pass `None()` as hostname because this test is still
+  // using the 'legacy' hostname validation scheme.
+  AWAIT_ASSERT_READY(client->connect(
+      address.get(),
+      openssl::create_tls_client_config(None())));
 
   // Wait for the server to have accepted the client connection.
   AWAIT_ASSERT_READY(accept);
@@ -972,6 +989,157 @@ TEST_P(SSLProtocolTest, Mismatch)
     Future<Socket> socket = server->accept();
     AWAIT_ASSERT_FAILED(socket);
 
     AWAIT_ASSERT_READY(await_subprocess(client.get(), None()));
   }
 }
+
+
+// Verify that we can make a connection using a custom SSL context,
+// and that the specified `verify` and `configure_socket` callbacks
+// are called.
+TEST_F(SSLTest, CustomSSLContext)
+{
+  static bool verify_called;
+  static bool configure_socket_called;
+
+  os::setenv("LIBPROCESS_SSL_ENABLED", "true");
+  os::setenv("LIBPROCESS_SSL_KEY_FILE", key_path().string());
+  os::setenv("LIBPROCESS_SSL_CERT_FILE", certificate_path().string());
+
+  openssl::reinitialize();
+
+  verify_called = false;
+  configure_socket_called = false;
+
+  SSL_CTX* ctx = SSL_CTX_new(SSLv23_client_method());
+  SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
+
+  openssl::TLSClientConfig config(
+    None(),
+    ctx,
+    [](SSL*, const network::Address&, const Option<std::string>&)
+      -> Try<Nothing>
+    {
+      configure_socket_called = true;
+      return Nothing();
+    },
+    [](const SSL* const, const Option<std::string>&, const Option<net::IP>&)
+      -> Try<Nothing>
+    {
+      verify_called = true;
+      return Nothing();
+    });
+
+  Try<Socket> client = Socket::create(SocketImpl::Kind::SSL);
+  ASSERT_SOME(client);
+
+  Try<Socket> server = Socket::create(SocketImpl::Kind::SSL);
+  ASSERT_SOME(server);
+
+  ASSERT_SOME(server->listen(1));
+  Try<Address> address = server->address();
+  ASSERT_SOME(address);
+
+  Future<Socket> socket = server->accept();
+  Future<Nothing> connected = client->connect(*address, config);
+
+  AWAIT_READY(socket);
+  AWAIT_READY(connected);
+
+  EXPECT_TRUE(verify_called);
+  EXPECT_TRUE(configure_socket_called);
+
+  SSL_CTX_free(ctx);
+}
+
+
+// Ensures that `connect()` fails if the passed
+// `configure_socket` callback returns an error.
+TEST_F(SSLTest, CustomSSLContextConfigureSocketFails)
+{
+  os::setenv("LIBPROCESS_SSL_ENABLED", "true");
+  os::setenv("LIBPROCESS_SSL_KEY_FILE", key_path().string());
+  os::setenv("LIBPROCESS_SSL_CERT_FILE", certificate_path().string());
+
+  openssl::reinitialize();
+
+  SSL_CTX* ctx = SSL_CTX_new(SSLv23_client_method());
+  SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
+
+  openssl::TLSClientConfig config(
+    None(),
+    ctx,
+    [](SSL*, const network::Address&, const Option<std::string>&)
+      -> Try<Nothing>
+    {
+      return Error("Configure socket.");
+    },
+    [](const SSL* const, const Option<std::string>&, const Option<net::IP>&)
+      -> Try<Nothing>
+    {
+      return Nothing();
+    });
+
+  Try<Socket> client = Socket::create(SocketImpl::Kind::SSL);
+  ASSERT_SOME(client);
+
+  Try<Socket> server = Socket::create(SocketImpl::Kind::SSL);
+  ASSERT_SOME(server);
+
+  ASSERT_SOME(server->listen(1));
+  Try<Address> address = server->address();
+  ASSERT_SOME(address);
+
+  Future<Socket> socket = server->accept();
+  Future<Nothing> connected = client->connect(*address, config);
+
+  AWAIT_ASSERT_FAILED(connected);
+
+  SSL_CTX_free(ctx);
+}
+
+
+// Ensures that `connect()` fails if the passed
+// `verify` callback returns an error.
+TEST_F(SSLTest, CustomSSLContextVerifyFails)
+{
+  os::setenv("LIBPROCESS_SSL_ENABLED", "true");
+  os::setenv("LIBPROCESS_SSL_KEY_FILE", key_path().string());
+  os::setenv("LIBPROCESS_SSL_CERT_FILE", certificate_path().string());
+
+  openssl::reinitialize();
+
+  SSL_CTX* ctx = SSL_CTX_new(SSLv23_client_method());
+  SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
+
+  openssl::TLSClientConfig config(
+    None(),
+    ctx,
+    [](SSL*, const network::Address&, const Option<std::string>&)
+      -> Try<Nothing>
+    {
+      return Nothing();
+    },
+    [](const SSL* const, const Option<std::string>&, const Option<net::IP>&)
+      -> Try<Nothing>
+    {
+      return Error("Verify failed.");
+    });
+
+  Try<Socket> client = Socket::create(SocketImpl::Kind::SSL);
+  ASSERT_SOME(client);
+
+  Try<Socket> server = Socket::create(SocketImpl::Kind::SSL);
+  ASSERT_SOME(server);
+
+  ASSERT_SOME(server->listen(1));
+  Try<Address> address = server->address();
+  ASSERT_SOME(address);
+
+  Future<Socket> socket = server->accept();
+  Future<Nothing> connected = client->connect(*address, config);
+
+  AWAIT_ASSERT_FAILED(connected);
+
+  SSL_CTX_free(ctx);
+}