You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2018/07/16 07:03:10 UTC

[5/8] flink git commit: [FLINK-9313] [security] (part 1) Instantiate all SSLSocket and SSLServerSocket through factories.

[FLINK-9313] [security] (part 1) Instantiate all SSLSocket and SSLServerSocket through factories.

This removes hostname verification from SSL client sockets.
With client authentication, hostname verification is no longer needed, and it is
incompatible with various container environments.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/4db63c03
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/4db63c03
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/4db63c03

Branch: refs/heads/master
Commit: 4db63c0379fc1e0d8aab573d3d9d6dfa68074d76
Parents: 7a912e6
Author: Stephan Ewen <se...@apache.org>
Authored: Thu Jul 12 11:28:57 2018 +0200
Committer: Till Rohrmann <tr...@apache.org>
Committed: Mon Jul 16 08:10:46 2018 +0200

----------------------------------------------------------------------
 .../apache/flink/runtime/blob/BlobClient.java   | 39 +++-----
 .../apache/flink/runtime/blob/BlobServer.java   | 52 ++++------
 .../org/apache/flink/runtime/net/SSLUtils.java  | 99 +++++++++++++++-----
 .../apache/flink/runtime/net/SSLUtilsTest.java  | 59 +++++++-----
 4 files changed, 143 insertions(+), 106 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/4db63c03/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java
index 80e36b6..2ca250c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.blob;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.BlobServerOptions;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.SecurityOptions;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.net.SSLUtils;
@@ -30,9 +31,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
-import javax.net.ssl.SSLContext;
-import javax.net.ssl.SSLParameters;
-import javax.net.ssl.SSLSocket;
 
 import java.io.Closeable;
 import java.io.EOFException;
@@ -68,7 +66,7 @@ public final class BlobClient implements Closeable {
 	private static final Logger LOG = LoggerFactory.getLogger(BlobClient.class);
 
 	/** The socket connection to the BLOB server. */
-	private Socket socket;
+	private final Socket socket;
 
 	/**
 	 * Instantiates a new BLOB client.
@@ -82,41 +80,28 @@ public final class BlobClient implements Closeable {
 	 *         thrown if the connection to the BLOB server could not be established
 	 */
 	public BlobClient(InetSocketAddress serverAddress, Configuration clientConfig) throws IOException {
+		Socket socket = null;
 
 		try {
-			// Check if ssl is enabled
-			SSLContext clientSSLContext = null;
-			if (clientConfig != null &&
-				clientConfig.getBoolean(BlobServerOptions.SSL_ENABLED)) {
-
-				clientSSLContext = SSLUtils.createSSLClientContext(clientConfig);
-			}
-
-			if (clientSSLContext != null) {
-
+			// create an SSL socket if configured
+			if (clientConfig.getBoolean(SecurityOptions.SSL_ENABLED) && clientConfig.getBoolean(BlobServerOptions.SSL_ENABLED)) {
 				LOG.info("Using ssl connection to the blob server");
 
-				SSLSocket sslSocket = (SSLSocket) clientSSLContext.getSocketFactory().createSocket(
+				socket = SSLUtils.createSSLClientSocketFactory(clientConfig).createSocket(
 					serverAddress.getAddress(),
 					serverAddress.getPort());
-
-				// Enable hostname verification for remote SSL connections
-				if (!serverAddress.getAddress().isLoopbackAddress()) {
-					SSLParameters newSSLParameters = sslSocket.getSSLParameters();
-					SSLUtils.setSSLVerifyHostname(clientConfig, newSSLParameters);
-					sslSocket.setSSLParameters(newSSLParameters);
-				}
-				this.socket = sslSocket;
-			} else {
-				this.socket = new Socket();
-				this.socket.connect(serverAddress);
 			}
-
+			else {
+				socket = new Socket();
+				socket.connect(serverAddress);
+			}
 		}
 		catch (Exception e) {
 			BlobUtils.closeSilently(socket, LOG);
 			throw new IOException("Could not connect to BlobServer at address " + serverAddress, e);
 		}
+
+		this.socket = socket;
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/4db63c03/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java
index dd0155c..ee1d50a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.BlobServerOptions;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.SecurityOptions;
 import org.apache.flink.runtime.net.SSLUtils;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.FileUtils;
@@ -33,7 +34,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
-import javax.net.ssl.SSLContext;
+import javax.net.ServerSocketFactory;
 
 import java.io.File;
 import java.io.FileNotFoundException;
@@ -78,9 +79,6 @@ public class BlobServer extends Thread implements BlobService, BlobWriter, Perma
 	/** The server socket listening for incoming connections. */
 	private final ServerSocket serverSocket;
 
-	/** The SSL server context if ssl is enabled for the connections. */
-	private final SSLContext serverSSLContext;
-
 	/** Blob Server configuration. */
 	private final Configuration blobServiceConfiguration;
 
@@ -172,40 +170,30 @@ public class BlobServer extends Thread implements BlobService, BlobWriter, Perma
 
 		this.shutdownHook = ShutdownHookUtil.addShutdownHook(this, getClass().getSimpleName(), LOG);
 
-		if (config.getBoolean(BlobServerOptions.SSL_ENABLED)) {
-			try {
-				serverSSLContext = SSLUtils.createSSLServerContext(config);
-			} catch (Exception e) {
-				throw new IOException("Failed to initialize SSLContext for the blob server", e);
-			}
-		} else {
-			serverSSLContext = null;
-		}
-
 		//  ----------------------- start the server -------------------
 
-		String serverPortRange = config.getString(BlobServerOptions.PORT);
+		final String serverPortRange = config.getString(BlobServerOptions.PORT);
+		final Iterator<Integer> ports = NetUtils.getPortRangeFromString(serverPortRange);
 
-		Iterator<Integer> ports = NetUtils.getPortRangeFromString(serverPortRange);
+		final ServerSocketFactory socketFactory;
+		if (config.getBoolean(SecurityOptions.SSL_ENABLED) && config.getBoolean(BlobServerOptions.SSL_ENABLED)) {
+			try {
+				socketFactory = SSLUtils.createSSLServerSocketFactory(config);
+			}
+			catch (Exception e) {
+				throw new IOException("Failed to initialize SSL for the blob server", e);
+			}
+		}
+		else {
+			socketFactory = ServerSocketFactory.getDefault();
+		}
 
 		final int finalBacklog = backlog;
-		ServerSocket socketAttempt = NetUtils.createSocketFromPorts(ports, new NetUtils.SocketFactory() {
-			@Override
-			public ServerSocket createSocket(int port) throws IOException {
-				if (serverSSLContext == null) {
-					return new ServerSocket(port, finalBacklog);
-				} else {
-					LOG.info("Enabling ssl for the blob server");
-					return serverSSLContext.getServerSocketFactory().createServerSocket(port, finalBacklog);
-				}
-			}
-		});
+		this.serverSocket = NetUtils.createSocketFromPorts(ports,
+				(port) -> socketFactory.createServerSocket(port, finalBacklog));
 
-		if (socketAttempt == null) {
-			throw new IOException("Unable to allocate socket for blob server in specified port range: " + serverPortRange);
-		} else {
-			SSLUtils.setSSLVerAndCipherSuites(socketAttempt, config);
-			this.serverSocket = socketAttempt;
+		if (serverSocket == null) {
+			throw new IOException("Unable to open BLOB Server in specified port range: " + serverPortRange);
 		}
 
 		// start the server thread

http://git-wip-us.apache.org/repos/asf/flink/blob/4db63c03/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java
index b574d30..2bfc0d6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.net;
 
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.IllegalConfigurationException;
 import org.apache.flink.configuration.SecurityOptions;
 import org.apache.flink.util.Preconditions;
 
@@ -26,26 +27,31 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
+import javax.net.ServerSocketFactory;
+import javax.net.SocketFactory;
 import javax.net.ssl.KeyManagerFactory;
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
 import javax.net.ssl.SSLParameters;
 import javax.net.ssl.SSLServerSocket;
+import javax.net.ssl.SSLServerSocketFactory;
 import javax.net.ssl.TrustManagerFactory;
 
 import java.io.File;
 import java.io.FileInputStream;
+import java.io.IOException;
+import java.net.InetAddress;
 import java.net.ServerSocket;
 import java.security.KeyStore;
-import java.util.Arrays;
 
-import static java.util.Objects.requireNonNull;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
  * Common utilities to manage SSL transport settings.
  */
 public class SSLUtils {
+
 	private static final Logger LOG = LoggerFactory.getLogger(SSLUtils.class);
 
 	/**
@@ -56,33 +62,35 @@ public class SSLUtils {
 	 * @return true if global ssl flag is set
 	 */
 	public static boolean getSSLEnabled(Configuration sslConfig) {
-
-		Preconditions.checkNotNull(sslConfig);
-
 		return sslConfig.getBoolean(SecurityOptions.SSL_ENABLED);
 	}
 
 	/**
-	 * Sets SSl version and cipher suites for SSLServerSocket.
-	 * @param socket
-	 *        Socket to be handled
-	 * @param config
-	 *        The application configuration
+	 * Creates a factory for SSL Server Sockets from the given configuration.
 	 */
-	public static void setSSLVerAndCipherSuites(ServerSocket socket, Configuration config) {
-		if (socket instanceof SSLServerSocket) {
-			final String[] protocols = config.getString(SecurityOptions.SSL_PROTOCOL).split(",");
+	public static ServerSocketFactory createSSLServerSocketFactory(Configuration config) throws Exception {
+		SSLContext sslContext = createSSLServerContext(config);
+		if (sslContext == null) {
+			throw new IllegalConfigurationException("SSL is not enabled");
+		}
 
-			final String[] cipherSuites = config.getString(SecurityOptions.SSL_ALGORITHMS).split(",");
+		String[] protocols = getEnabledProtocols(config);
+		String[] cipherSuites = getEnabledCipherSuites(config);
 
-			if (LOG.isDebugEnabled()) {
-				LOG.debug("Configuring TLS version and cipher suites on SSL socket {} / {}",
-						Arrays.toString(protocols), Arrays.toString(cipherSuites));
-			}
+		SSLServerSocketFactory factory = sslContext.getServerSocketFactory();
+		return new ConfiguringSSLServerSocketFactory(factory, protocols, cipherSuites);
+	}
 
-			((SSLServerSocket) socket).setEnabledProtocols(protocols);
-			((SSLServerSocket) socket).setEnabledCipherSuites(cipherSuites);
+	/**
+	 * Creates a factory for SSL client sockets from the client SSL context of the given configuration.
+	 */
+	public static SocketFactory createSSLClientSocketFactory(Configuration config) throws Exception {
+		SSLContext sslContext = createSSLClientContext(config);
+		if (sslContext == null) {
+			throw new IllegalConfigurationException("SSL is not enabled");
 		}
+
+		return sslContext.getSocketFactory();
 	}
 
 	/**
@@ -134,12 +142,12 @@ public class SSLUtils {
 	}
 
 	private static String[] getEnabledProtocols(final Configuration config) {
-		requireNonNull(config, "config must not be null");
+		checkNotNull(config, "config must not be null");
 		return config.getString(SecurityOptions.SSL_PROTOCOL).split(",");
 	}
 
 	private static String[] getEnabledCipherSuites(final Configuration config) {
-		requireNonNull(config, "config must not be null");
+		checkNotNull(config, "config must not be null");
 		return config.getString(SecurityOptions.SSL_ALGORITHMS).split(",");
 	}
 
@@ -259,4 +267,51 @@ public class SSLUtils {
 
 		return serverSSLContext;
 	}
+
+	// ------------------------------------------------------------------------
+	//  Wrappers for socket factories that additionally configure the sockets
+	// ------------------------------------------------------------------------
+
+	private static class ConfiguringSSLServerSocketFactory extends ServerSocketFactory {
+
+		private final SSLServerSocketFactory sslServerSocketFactory;
+		private final String[] protocols;
+		private final String[] cipherSuites;
+
+		ConfiguringSSLServerSocketFactory(
+				SSLServerSocketFactory sslServerSocketFactory,
+				String[] protocols,
+				String[] cipherSuites) {
+
+			this.sslServerSocketFactory = sslServerSocketFactory;
+			this.protocols = protocols;
+			this.cipherSuites = cipherSuites;
+		}
+
+		@Override
+		public ServerSocket createServerSocket(int port) throws IOException {
+			SSLServerSocket socket = (SSLServerSocket) sslServerSocketFactory.createServerSocket(port);
+			configureServerSocket(socket);
+			return socket;
+		}
+
+		@Override
+		public ServerSocket createServerSocket(int port, int backlog) throws IOException {
+			SSLServerSocket socket = (SSLServerSocket) sslServerSocketFactory.createServerSocket(port, backlog);
+			configureServerSocket(socket);
+			return socket;
+		}
+
+		@Override
+		public ServerSocket createServerSocket(int port, int backlog, InetAddress ifAddress) throws IOException {
+			SSLServerSocket socket = (SSLServerSocket) sslServerSocketFactory.createServerSocket(port, backlog, ifAddress);
+			configureServerSocket(socket);
+			return socket;
+		}
+
+		private void configureServerSocket(SSLServerSocket socket) {
+			socket.setEnabledProtocols(protocols);
+			socket.setEnabledCipherSuites(cipherSuites);
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/4db63c03/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java
index 1bf3173..cdc121a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java
@@ -19,8 +19,10 @@
 package org.apache.flink.runtime.net;
 
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.IllegalConfigurationException;
 import org.apache.flink.configuration.SecurityOptions;
 
+import org.hamcrest.collection.IsArrayContainingInAnyOrder;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -33,7 +35,10 @@ import java.util.Arrays;
 
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /**
  * Tests for the {@link SSLUtils}.
@@ -157,44 +162,48 @@ public class SSLUtilsTest {
 		}
 	}
 
+	@Test
+	public void testSocketFactoriesWhenSslDisables() throws Exception {
+		Configuration config = new Configuration();
+
+		try {
+			SSLUtils.createSSLServerSocketFactory(config);
+			fail("exception expected");
+		} catch (IllegalConfigurationException ignored) {}
+
+		try {
+			SSLUtils.createSSLClientSocketFactory(config);
+			fail("exception expected");
+		} catch (IllegalConfigurationException ignored) {}
+	}
+
 	/**
 	 * Tests if SSLUtils set the right ssl version and cipher suites for SSLServerSocket.
 	 */
 	@Test
 	public void testSetSSLVersionAndCipherSuitesForSSLServerSocket() throws Exception {
-
 		Configuration serverConfig = new Configuration();
 		serverConfig.setBoolean(SecurityOptions.SSL_ENABLED, true);
 		serverConfig.setString(SecurityOptions.SSL_KEYSTORE, "src/test/resources/local127.keystore");
 		serverConfig.setString(SecurityOptions.SSL_KEYSTORE_PASSWORD, "password");
 		serverConfig.setString(SecurityOptions.SSL_KEY_PASSWORD, "password");
+
+		// set custom protocol and cipher suites
 		serverConfig.setString(SecurityOptions.SSL_PROTOCOL, "TLSv1.1");
 		serverConfig.setString(SecurityOptions.SSL_ALGORITHMS, "TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_128_CBC_SHA256");
 
-		SSLContext serverContext = SSLUtils.createSSLServerContext(serverConfig);
-		ServerSocket socket = null;
-		try {
-			socket = serverContext.getServerSocketFactory().createServerSocket(0);
-
-			String[] protocols = ((SSLServerSocket) socket).getEnabledProtocols();
-			String[] algorithms = ((SSLServerSocket) socket).getEnabledCipherSuites();
-
-			Assert.assertNotEquals(1, protocols.length);
-			Assert.assertNotEquals(2, algorithms.length);
-
-			SSLUtils.setSSLVerAndCipherSuites(socket, serverConfig);
-			protocols = ((SSLServerSocket) socket).getEnabledProtocols();
-			algorithms = ((SSLServerSocket) socket).getEnabledCipherSuites();
-
-			Assert.assertEquals(1, protocols.length);
-			Assert.assertEquals("TLSv1.1", protocols[0]);
-			Assert.assertEquals(2, algorithms.length);
-			Assert.assertTrue(algorithms[0].equals("TLS_RSA_WITH_AES_128_CBC_SHA") || algorithms[0].equals("TLS_RSA_WITH_AES_128_CBC_SHA256"));
-			Assert.assertTrue(algorithms[1].equals("TLS_RSA_WITH_AES_128_CBC_SHA") || algorithms[1].equals("TLS_RSA_WITH_AES_128_CBC_SHA256"));
-		} finally {
-			if (socket != null) {
-				socket.close();
-			}
+		try (ServerSocket socket = SSLUtils.createSSLServerSocketFactory(serverConfig).createServerSocket(0)) {
+			assertTrue(socket instanceof SSLServerSocket);
+			final SSLServerSocket sslSocket = (SSLServerSocket) socket;
+
+			String[] protocols = sslSocket.getEnabledProtocols();
+			String[] algorithms = sslSocket.getEnabledCipherSuites();
+
+			assertEquals(1, protocols.length);
+			assertEquals("TLSv1.1", protocols[0]);
+			assertEquals(2, algorithms.length);
+			assertThat(algorithms, IsArrayContainingInAnyOrder.arrayContainingInAnyOrder(
+					"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_128_CBC_SHA256"));
 		}
 	}