You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nifi.apache.org by pv...@apache.org on 2021/10/02 16:31:15 UTC
[nifi] branch main updated: NIFI-9253 Corrected SSLSocketChannel.available() for TLSv1.3
This is an automated email from the ASF dual-hosted git repository.
pvillard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git
The following commit(s) were added to refs/heads/main by this push:
new defea61 NIFI-9253 Corrected SSLSocketChannel.available() for TLSv1.3
defea61 is described below
commit defea610754a53abe71a2edd00c5485646d424f3
Author: exceptionfactory <ex...@apache.org>
AuthorDate: Tue Sep 28 17:00:17 2021 -0500
NIFI-9253 Corrected SSLSocketChannel.available() for TLSv1.3
- Added unit tests to reproduce issues with available() method
- Changed available() to return size of application buffer
- Removed unused isDataAvailable()
- Refactored unwrap handling to read from the channel on buffer underflow
Signed-off-by: Pierre Villard <pi...@gmail.com>
This closes #5421.
---
.../remote/io/socket/ssl/SSLSocketChannel.java | 160 ++++++++-------------
.../io/socket/ssl/SSLSocketChannelInputStream.java | 4 -
.../remote/io/socket/ssl/SSLSocketChannelTest.java | 149 +++++++++++++------
3 files changed, 166 insertions(+), 147 deletions(-)
diff --git a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
index 9a5cdd8..beb0933 100644
--- a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
+++ b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
@@ -30,6 +30,7 @@ import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;
import java.io.Closeable;
+import java.io.EOFException;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
@@ -47,6 +48,7 @@ import java.util.concurrent.TimeUnit;
public class SSLSocketChannel implements Closeable {
private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class);
+ private static final int MINIMUM_READ_BUFFER_SIZE = 1;
private static final int DISCARD_BUFFER_LENGTH = 8192;
private static final int END_OF_STREAM = -1;
private static final byte[] EMPTY_MESSAGE = new byte[0];
@@ -266,7 +268,7 @@ public class SSLSocketChannel implements Closeable {
status = wrapResult.getStatus();
}
if (Status.CLOSED == status) {
- final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(1);
+ final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
try {
writeChannel(streamOutputBuffer);
} catch (final IOException e) {
@@ -291,39 +293,8 @@ public class SSLSocketChannel implements Closeable {
* @throws IOException Thrown on failures checking for available bytes
*/
public int available() throws IOException {
- ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
- ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
- final int buffered = appDataBuffer.remaining() + streamDataBuffer.remaining();
- if (buffered > 0) {
- return buffered;
- }
-
- if (!isDataAvailable()) {
- return 0;
- }
-
- appDataBuffer = appDataManager.prepareForRead(1);
- streamDataBuffer = streamInManager.prepareForRead(1);
- return appDataBuffer.remaining() + streamDataBuffer.remaining();
- }
-
- /**
- * Is data available for reading
- *
- * @return Data available status
- * @throws IOException Thrown on SocketChannel.read() failures
- */
- public boolean isDataAvailable() throws IOException {
- final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
- final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
-
- if (appDataBuffer.remaining() > 0 || streamDataBuffer.remaining() > 0) {
- return true;
- }
-
- final ByteBuffer writableBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
- final int bytesRead = channel.read(writableBuffer);
- return (bytesRead > 0);
+ final ByteBuffer appDataBuffer = appDataManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
+ return appDataBuffer.remaining();
}
/**
@@ -373,42 +344,24 @@ public class SSLSocketChannel implements Closeable {
}
appDataManager.clear();
- while (true) {
- final SSLEngineResult unwrapResult = unwrap();
-
- if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
- // RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
- logOperation("Processing Post-Handshake Messages");
- continue;
+ final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
+ final Status status = unwrapResult.getStatus();
+ if (Status.CLOSED == status) {
+ applicationBytesRead = readApplicationBuffer(buffer, offset, len);
+ if (applicationBytesRead == 0) {
+ return END_OF_STREAM;
}
-
- final Status status = unwrapResult.getStatus();
- switch (status) {
- case BUFFER_OVERFLOW:
- throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not allowed from unwrap", status));
- case BUFFER_UNDERFLOW:
- final ByteBuffer streamBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
- final int channelBytesRead = readChannel(streamBuffer);
- logOperationBytes("Channel Read Completed", channelBytesRead);
- if (channelBytesRead == END_OF_STREAM) {
- return END_OF_STREAM;
- }
- break;
- case CLOSED:
- applicationBytesRead = readApplicationBuffer(buffer, offset, len);
- if (applicationBytesRead == 0) {
- return END_OF_STREAM;
- }
- streamInManager.compact();
- return applicationBytesRead;
- case OK:
- applicationBytesRead = readApplicationBuffer(buffer, offset, len);
- if (applicationBytesRead == 0) {
- throw new IOException("Read Application Buffer Failed");
- }
- streamInManager.compact();
- return applicationBytesRead;
+ streamInManager.compact();
+ return applicationBytesRead;
+ } else if (Status.OK == status) {
+ applicationBytesRead = readApplicationBuffer(buffer, offset, len);
+ if (applicationBytesRead == 0) {
+ throw new IOException("Read Application Buffer Failed");
}
+ streamInManager.compact();
+ return applicationBytesRead;
+ } else {
+ throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not expected from unwrap", status));
}
}
@@ -508,24 +461,13 @@ public class SSLSocketChannel implements Closeable {
handshakeStatus = engine.getHandshakeStatus();
break;
case NEED_UNWRAP:
- final SSLEngineResult unwrapResult = unwrap();
+ final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
handshakeStatus = unwrapResult.getHandshakeStatus();
- Status unwrapResultStatus = unwrapResult.getStatus();
-
- if (unwrapResultStatus == Status.BUFFER_UNDERFLOW) {
- final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
- final int bytesRead = readChannel(writableDataIn);
- logOperationBytes("Handshake Channel Read", bytesRead);
-
- if (bytesRead == END_OF_STREAM) {
- throw getHandshakeException(handshakeStatus, "End of Stream Found");
- }
- } else if (unwrapResultStatus == Status.CLOSED) {
+ if (unwrapResult.getStatus() == Status.CLOSED) {
throw getHandshakeException(handshakeStatus, "Channel Closed");
- } else {
- streamInManager.compact();
- appDataManager.clear();
}
+ streamInManager.compact();
+ appDataManager.clear();
break;
case NEED_WRAP:
final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
@@ -536,7 +478,7 @@ public class SSLSocketChannel implements Closeable {
if (wrapResultStatus == Status.BUFFER_OVERFLOW) {
streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
} else if (wrapResultStatus == Status.OK) {
- final ByteBuffer streamBuffer = streamOutManager.prepareForRead(1);
+ final ByteBuffer streamBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
final int bytesRemaining = streamBuffer.remaining();
writeChannel(streamBuffer);
logOperationBytes("Handshake Channel Write Completed", bytesRemaining);
@@ -549,8 +491,29 @@ public class SSLSocketChannel implements Closeable {
}
}
- private int readChannel(final ByteBuffer outputBuffer) throws IOException {
+ private SSLEngineResult unwrapBufferReadChannel() throws IOException {
+ SSLEngineResult unwrapResult = unwrap();
+
+ while (Status.BUFFER_UNDERFLOW == unwrapResult.getStatus()) {
+ final int channelBytesRead = readChannel();
+ if (channelBytesRead == END_OF_STREAM) {
+ throw new EOFException("End of Stream found for Channel Read");
+ }
+
+ unwrapResult = unwrap();
+ if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
+ // RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
+ logOperation("Processing Post-Handshake Messages");
+ unwrapResult = unwrap();
+ }
+ }
+
+ return unwrapResult;
+ }
+
+ private int readChannel() throws IOException {
logOperation("Channel Read Started");
+ final ByteBuffer outputBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final long started = System.currentTimeMillis();
long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
@@ -568,10 +531,24 @@ public class SSLSocketChannel implements Closeable {
continue;
}
+ logOperationBytes("Channel Read Completed", channelBytesRead);
return channelBytesRead;
}
}
+ private void readChannelDiscard() {
+ try {
+ final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
+ int bytesRead = channel.read(readBuffer);
+ while (bytesRead > 0) {
+ readBuffer.clear();
+ bytesRead = channel.read(readBuffer);
+ }
+ } catch (final IOException e) {
+ LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
+ }
+ }
+
private void writeChannel(final ByteBuffer inputBuffer) throws IOException {
long lastWriteCompleted = System.currentTimeMillis();
@@ -605,19 +582,6 @@ public class SSLSocketChannel implements Closeable {
return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
}
- private void readChannelDiscard() {
- try {
- final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
- int bytesRead = channel.read(readBuffer);
- while (bytesRead > 0) {
- readBuffer.clear();
- bytesRead = channel.read(readBuffer);
- }
- } catch (final IOException e) {
- LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
- }
- }
-
private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) {
logOperationBytes("Application Buffer Read Requested", len);
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len);
diff --git a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelInputStream.java b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelInputStream.java
index ca6de85..5bc903f 100644
--- a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelInputStream.java
+++ b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelInputStream.java
@@ -58,8 +58,4 @@ public class SSLSocketChannelInputStream extends InputStream {
public int available() throws IOException {
return channel.available();
}
-
- public boolean isDataAvailable() throws IOException {
- return available() > 0;
- }
}
diff --git a/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java b/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java
index d770dd9..cb15959 100644
--- a/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java
+++ b/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java
@@ -38,9 +38,10 @@ import org.apache.nifi.security.util.SslContextFactory;
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
import org.apache.nifi.security.util.TlsConfiguration;
import org.apache.nifi.security.util.TlsPlatform;
-import org.junit.Assume;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.condition.EnabledIf;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
@@ -55,17 +56,19 @@ import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertThrows;
-import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+@Timeout(value = 15)
public class SSLSocketChannelTest {
private static final String LOCALHOST = "localhost";
@@ -81,10 +84,10 @@ public class SSLSocketChannelTest {
private static final int CHANNEL_POLL_TIMEOUT = 5000;
- private static final long CHANNEL_SLEEP_BEFORE_READ = 100;
-
private static final int MAX_MESSAGE_LENGTH = 1024;
+ private static final long SHUTDOWN_TIMEOUT = 100;
+
private static final String TLS_1_3 = "TLSv1.3";
private static final String TLS_1_2 = "TLSv1.2";
@@ -97,9 +100,17 @@ public class SSLSocketChannelTest {
private static final int FIRST_BYTE_OFFSET = 1;
+ private static final int SINGLE_COUNT_DOWN = 1;
+
private static SSLContext sslContext;
- @BeforeClass
+ private static final String TLS_1_3_SUPPORTED = "isTls13Supported";
+
+ public static boolean isTls13Supported() {
+ return TlsPlatform.getSupportedProtocols().contains(TLS_1_3);
+ }
+
+ @BeforeAll
public static void setConfiguration() throws GeneralSecurityException {
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
@@ -115,54 +126,60 @@ public class SSLSocketChannelTest {
@Test
public void testClientConnectHandshakeFailed() throws IOException {
- assumeProtocolSupported(TLS_1_2);
+ final String enabledProtocol = isTls13Supported() ? TLS_1_3 : TLS_1_2;
+
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
try (final SocketChannel socketChannel = SocketChannel.open()) {
final int port = NetworkUtils.getAvailableTcpPort();
- startServer(group, port, TLS_1_2);
+ startServer(group, port, enabledProtocol, getSingleCountDownLatch());
socketChannel.connect(new InetSocketAddress(LOCALHOST, port));
- final SSLEngine sslEngine = createSslEngine(TLS_1_2, CLIENT_CHANNEL);
+ final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
- group.shutdownGracefully().syncUninterruptibly();
+ shutdownGroup(group);
assertThrows(SSLException.class, sslSocketChannel::connect);
} finally {
- group.shutdownGracefully().syncUninterruptibly();
+ shutdownGroup(group);
}
}
@Test
public void testClientConnectWriteReadTls12() throws Exception {
- assumeProtocolSupported(TLS_1_2);
assertChannelConnectedWriteReadClosed(TLS_1_2);
}
+ @EnabledIf(TLS_1_3_SUPPORTED)
@Test
public void testClientConnectWriteReadTls13() throws Exception {
- assumeProtocolSupported(TLS_1_3);
assertChannelConnectedWriteReadClosed(TLS_1_3);
}
- @Test(timeout = CHANNEL_TIMEOUT)
+ @Test
+ public void testClientConnectWriteAvailableReadTls12() throws Exception {
+ assertChannelConnectedWriteAvailableRead(TLS_1_2);
+ }
+
+ @EnabledIf(TLS_1_3_SUPPORTED)
+ @Test
+ public void testClientConnectWriteAvailableReadTls13() throws Exception {
+ assertChannelConnectedWriteAvailableRead(TLS_1_3);
+ }
+
+ @Test
public void testServerReadWriteTls12() throws Exception {
- assumeProtocolSupported(TLS_1_2);
assertServerChannelConnectedReadClosed(TLS_1_2);
}
- @Test(timeout = CHANNEL_TIMEOUT)
+ @EnabledIf(TLS_1_3_SUPPORTED)
+ @Test
public void testServerReadWriteTls13() throws Exception {
- assumeProtocolSupported(TLS_1_3);
assertServerChannelConnectedReadClosed(TLS_1_3);
}
- private void assumeProtocolSupported(final String protocol) {
- Assume.assumeTrue(String.format("Protocol [%s] not supported", protocol), TlsPlatform.getSupportedProtocols().contains(protocol));
- }
-
private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
final int port = NetworkUtils.getAvailableTcpPort();
final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
@@ -194,67 +211,100 @@ public class SSLSocketChannelTest {
channel.writeAndFlush(MESSAGE).syncUninterruptibly();
final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
- assertEquals("Message not matched", MESSAGE, messageRead);
+ assertEquals(MESSAGE, messageRead, "Message not matched");
} finally {
channel.close();
}
} finally {
- group.shutdownGracefully().syncUninterruptibly();
+ shutdownGroup(group);
serverSocketChannel.close();
}
}
private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
- processClientSslSocketChannel(enabledProtocol, (sslSocketChannel -> {
+ final CountDownLatch countDownLatch = getSingleCountDownLatch();
+ processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
try {
sslSocketChannel.connect();
- assertFalse("Channel closed", sslSocketChannel.isClosed());
+ assertFalse(sslSocketChannel.isClosed());
- assertChannelWriteRead(sslSocketChannel);
+ assertChannelWriteRead(sslSocketChannel, countDownLatch);
sslSocketChannel.close();
- assertTrue("Channel not closed", sslSocketChannel.isClosed());
+ assertTrue(sslSocketChannel.isClosed());
} catch (final IOException e) {
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
}
}));
}
- private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel) throws IOException {
- sslSocketChannel.write(MESSAGE_BYTES);
-
- while (sslSocketChannel.available() == 0) {
+ private void assertChannelConnectedWriteAvailableRead(final String enabledProtocol) throws IOException {
+ final CountDownLatch countDownLatch = getSingleCountDownLatch();
+ processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
try {
- TimeUnit.MILLISECONDS.sleep(CHANNEL_SLEEP_BEFORE_READ);
- } catch (final InterruptedException e) {
- throw new RuntimeException(e);
+ sslSocketChannel.connect();
+ assertFalse(sslSocketChannel.isClosed());
+
+ assertChannelWriteAvailableRead(sslSocketChannel, countDownLatch);
+
+ sslSocketChannel.close();
+ assertTrue(sslSocketChannel.isClosed());
+ } catch (final IOException e) {
+ throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
}
+ }));
+ }
+
+ private void assertChannelWriteAvailableRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
+ sslSocketChannel.write(MESSAGE_BYTES);
+ sslSocketChannel.available();
+ awaitCountDownLatch(countDownLatch);
+ assetMessageRead(sslSocketChannel);
+ }
+
+ private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
+ sslSocketChannel.write(MESSAGE_BYTES);
+ awaitCountDownLatch(countDownLatch);
+ assetMessageRead(sslSocketChannel);
+ }
+
+ private void awaitCountDownLatch(final CountDownLatch countDownLatch) throws IOException {
+ try {
+ countDownLatch.await();
+ } catch (final InterruptedException e) {
+ throw new IOException("Count Down Interrupted", e);
}
+ }
+ private void assetMessageRead(final SSLSocketChannel sslSocketChannel) throws IOException {
final byte firstByteRead = (byte) sslSocketChannel.read();
- assertEquals("Channel Message first byte not matched", MESSAGE_BYTES[0], firstByteRead);
+ assertEquals(MESSAGE_BYTES[0], firstByteRead, "Channel Message first byte not matched");
+
+ final int available = sslSocketChannel.available();
+ final int availableExpected = MESSAGE_BYTES.length - FIRST_BYTE_OFFSET;
+ assertEquals(availableExpected, available, "Available Bytes not matched");
final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
messageBytes[0] = firstByteRead;
final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length);
- assertEquals("Channel Message Bytes Read not matched", messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead);
+ assertEquals(messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead, "Channel Message Bytes Read not matched");
final String message = new String(messageBytes, MESSAGE_CHARSET);
- assertEquals("Channel Message not matched", MESSAGE, message);
+ assertEquals(MESSAGE, message, "Message not matched");
}
- private void processClientSslSocketChannel(final String enabledProtocol, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
+ private void processClientSslSocketChannel(final String enabledProtocol, final CountDownLatch countDownLatch, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
try {
final int port = NetworkUtils.getAvailableTcpPort();
- startServer(group, port, enabledProtocol);
+ startServer(group, port, enabledProtocol, countDownLatch);
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
channelConsumer.accept(sslSocketChannel);
} finally {
- group.shutdownGracefully().syncUninterruptibly();
+ shutdownGroup(group);
}
}
@@ -273,7 +323,7 @@ public class SSLSocketChannelTest {
return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
}
- private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) {
+ private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol, final CountDownLatch countDownLatch) {
final ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(group);
bootstrap.channel(NioServerSocketChannel.class);
@@ -287,6 +337,7 @@ public class SSLSocketChannelTest {
@Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception {
channelHandlerContext.channel().writeAndFlush(MESSAGE).sync();
+ countDownLatch.countDown();
}
});
}
@@ -309,4 +360,12 @@ public class SSLSocketChannelTest {
pipeline.addLast(new StringDecoder());
pipeline.addLast(new StringEncoder());
}
+
+ private void shutdownGroup(final EventLoopGroup group) {
+ group.shutdownGracefully(SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS).syncUninterruptibly();
+ }
+
+ private CountDownLatch getSingleCountDownLatch() {
+ return new CountDownLatch(SINGLE_COUNT_DOWN);
+ }
}