You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by ij...@apache.org on 2017/09/22 19:58:29 UTC
kafka git commit: KAFKA-5920; Handle SSL handshake failures as authentication exceptions
Repository: kafka
Updated Branches:
refs/heads/trunk e554dc518 -> d60f011d7
KAFKA-5920; Handle SSL handshake failures as authentication exceptions
1. Propagate `SSLException` as `SslAuthenticationException` to enable clients to report these and avoid retries
2. Update `SslTransportLayer` to process received bytes even at end-of-stream
3. Tidy up authentication handling
4. Report exceptions in SaslClientAuthenticator as AuthenticationExceptions
Author: Rajini Sivaram <ra...@googlemail.com>
Reviewers: Ismael Juma <is...@juma.me.uk>
Closes #3918 from rajinisivaram/KAFKA-5920-SSL-handshake-failure
Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/d60f011d
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/d60f011d
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/d60f011d
Branch: refs/heads/trunk
Commit: d60f011d77ce80a44b02d43bf0889a50a8797dcd
Parents: e554dc5
Author: Rajini Sivaram <ra...@googlemail.com>
Authored: Fri Sep 22 20:26:46 2017 +0100
Committer: Ismael Juma <is...@juma.me.uk>
Committed: Fri Sep 22 20:29:25 2017 +0100
----------------------------------------------------------------------
.../common/errors/AuthenticationException.java | 3 +
.../errors/SslAuthenticationException.java | 44 +++
.../kafka/common/network/Authenticator.java | 17 +-
.../kafka/common/network/KafkaChannel.java | 31 +-
.../common/network/PlaintextChannelBuilder.java | 7 -
.../apache/kafka/common/network/Selector.java | 3 +
.../kafka/common/network/SslChannelBuilder.java | 8 -
.../kafka/common/network/SslTransportLayer.java | 296 +++++++++++-------
.../kafka/common/network/TransportLayer.java | 12 +-
.../authenticator/SaslClientAuthenticator.java | 31 +-
.../authenticator/SaslServerAuthenticator.java | 16 +-
.../kafka/common/network/NioEchoServer.java | 12 +-
.../common/network/SslTransportLayerTest.java | 305 +++++++++++++------
13 files changed, 506 insertions(+), 279 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java
index c56ac88..f6458c6 100644
--- a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java
+++ b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java
@@ -16,6 +16,8 @@
*/
package org.apache.kafka.common.errors;
+import javax.net.ssl.SSLException;
+
/**
* This exception indicates that SASL authentication has failed.
* On authentication failure, clients abort the operation requested and raise one
@@ -27,6 +29,7 @@ package org.apache.kafka.common.errors;
* is not supported on the broker.</li>
* <li>{@link IllegalSaslStateException} if an unexpected request is received on during SASL
* handshake. This could be due to misconfigured security protocol.</li>
+ * <li>{@link SslAuthenticationException} if SSL handshake failed due to any {@link SSLException}.</li>
* </ul>
*/
public class AuthenticationException extends ApiException {
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java
new file mode 100644
index 0000000..3cdbf2a
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.errors;
+
+import javax.net.ssl.SSLException;
+
+/**
+ * This exception indicates that SSL handshake has failed. See {@link #getCause()}
+ * for the {@link SSLException} that caused this failure.
+ * <p>
+ * SSL handshake failures in clients may indicate client authentication
+ * failure due to untrusted certificates if server is configured to request
+ * client certificates. Handshake failures could also indicate misconfigured
+ * security including protocol/cipher suite mismatch, server certificate
+ * authentication failure or server host name verification failure.
+ * </p>
+ */
+public class SslAuthenticationException extends AuthenticationException {
+
+ private static final long serialVersionUID = 1L;
+
+ public SslAuthenticationException(String message) {
+ super(message);
+ }
+
+ public SslAuthenticationException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java
index fa1123e..4e2e727 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java
@@ -16,7 +16,7 @@
*/
package org.apache.kafka.common.network;
-import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import java.io.Closeable;
@@ -28,15 +28,14 @@ import java.io.IOException;
public interface Authenticator extends Closeable {
/**
* Implements any authentication mechanism. Use transportLayer to read or write tokens.
- * If no further authentication needs to be done returns.
+ * For security protocols PLAINTEXT and SSL, this is a no-op since no further authentication
+ * needs to be done. For SASL_PLAINTEXT and SASL_SSL, this performs the SASL authentication.
+ *
+ * @throws AuthenticationException if authentication fails due to invalid credentials or
+ * other security configuration errors
+ * @throws IOException if read/write fails due to an I/O error
*/
- void authenticate() throws IOException;
-
- /**
- * Returns the first error encountered during authentication
- * @return authentication error if authentication failed, Errors.NONE otherwise
- */
- Errors error();
+ void authenticate() throws AuthenticationException, IOException;
/**
* Returns Principal using PrincipalBuilder
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
index 24cd9cf..f07035a 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
@@ -69,26 +69,21 @@ public class KafkaChannel {
}
/**
- * Does handshake of transportLayer and authentication using configured authenticator
+ * Does handshake of transportLayer and authentication using configured authenticator.
+ * For SSL with client authentication enabled, {@link TransportLayer#handshake()} performs
+ * authentication. For SASL, authentication is performed by {@link Authenticator#authenticate()}.
*/
- public void prepare() throws IOException {
- if (!transportLayer.ready())
- transportLayer.handshake();
- if (transportLayer.ready() && !authenticator.complete()) {
- try {
+ public void prepare() throws AuthenticationException, IOException {
+ try {
+ if (!transportLayer.ready())
+ transportLayer.handshake();
+ if (transportLayer.ready() && !authenticator.complete())
authenticator.authenticate();
- } catch (AuthenticationException e) {
- switch (authenticator.error()) {
- case SASL_AUTHENTICATION_FAILED:
- case ILLEGAL_SASL_STATE:
- case UNSUPPORTED_SASL_MECHANISM:
- state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e);
- break;
- default:
- // Other errors are handled as network exceptions in Selector
- }
- throw e;
- }
+ } catch (AuthenticationException e) {
+ // Clients are notified of authentication exceptions to enable operations to be terminated
+ // without retries. Other errors are handled as network exceptions in Selector.
+ state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e);
+ throw e;
}
if (ready())
state = ChannelState.READY;
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java
index 95fd903..c0d1059 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java
@@ -18,7 +18,6 @@ package org.apache.kafka.common.network;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.memory.MemoryPool;
-import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
import org.apache.kafka.common.security.auth.PlaintextAuthenticationContext;
@@ -80,12 +79,6 @@ public class PlaintextChannelBuilder implements ChannelBuilder {
}
@Override
- public Errors error() {
- // PLAINTEXT never fails authentication
- return Errors.NONE;
- }
-
- @Override
public void close() {
if (principalBuilder instanceof Closeable)
Utils.closeQuietly((Closeable) principalBuilder, "principal builder");
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/Selector.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
index 7977879..b753745 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
@@ -42,6 +42,7 @@ import org.apache.kafka.common.memory.MemoryPool;
import org.apache.kafka.common.metrics.Measurable;
import org.apache.kafka.common.metrics.MetricConfig;
import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.common.metrics.stats.Avg;
@@ -485,6 +486,8 @@ public class Selector implements Selectable, AutoCloseable {
String desc = channel.socketDescription();
if (e instanceof IOException)
log.debug("Connection with {} disconnected", desc, e);
+ else if (e instanceof AuthenticationException) // will be logged later as error by clients
+ log.debug("Connection with {} disconnected due to authentication exception", desc, e);
else
log.warn("Unexpected error from {}; closing connection", desc, e);
close(channel, true);
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java
index 80b9e9a..9519e58 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java
@@ -18,7 +18,6 @@ package org.apache.kafka.common.network;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.memory.MemoryPool;
-import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
import org.apache.kafka.common.security.auth.SslAuthenticationContext;
@@ -158,12 +157,5 @@ public class SslChannelBuilder implements ChannelBuilder {
public boolean complete() {
return true;
}
-
- @Override
- public Errors error() {
- // SSL authentication failures are currently not propagated to clients
- return Errors.NONE;
- }
-
}
}
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
index 3cd0114..f5e1e70 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
@@ -34,6 +34,7 @@ import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLPeerUnverifiedException;
+import org.apache.kafka.common.errors.SslAuthenticationException;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger;
@@ -44,6 +45,14 @@ import org.slf4j.LoggerFactory;
*/
public class SslTransportLayer implements TransportLayer {
private static final Logger log = LoggerFactory.getLogger(SslTransportLayer.class);
+
+ private enum State {
+ HANDSHAKE,
+ HANDSHAKE_FAILED,
+ READY,
+ CLOSING
+ }
+
private final String channelId;
private final SSLEngine sslEngine;
private final SelectionKey key;
@@ -52,8 +61,8 @@ public class SslTransportLayer implements TransportLayer {
private HandshakeStatus handshakeStatus;
private SSLEngineResult handshakeResult;
- private boolean handshakeComplete = false;
- private boolean closing = false;
+ private State state;
+ private SslAuthenticationException handshakeException;
private ByteBuffer netReadBuffer;
private ByteBuffer netWriteBuffer;
private ByteBuffer appReadBuffer;
@@ -89,8 +98,7 @@ public class SslTransportLayer implements TransportLayer {
netWriteBuffer.limit(0);
netReadBuffer.position(0);
netReadBuffer.limit(0);
- handshakeComplete = false;
- closing = false;
+ state = State.HANDSHAKE;
//initiate handshake
sslEngine.beginHandshake();
handshakeStatus = sslEngine.getHandshakeStatus();
@@ -98,7 +106,7 @@ public class SslTransportLayer implements TransportLayer {
@Override
public boolean ready() {
- return handshakeComplete;
+ return state == State.READY;
}
/**
@@ -141,8 +149,8 @@ public class SslTransportLayer implements TransportLayer {
*/
@Override
public void close() throws IOException {
- if (closing) return;
- closing = true;
+ if (state == State.CLOSING) return;
+ state = State.CLOSING;
sslEngine.closeOutbound();
try {
if (isConnected()) {
@@ -183,12 +191,22 @@ public class SslTransportLayer implements TransportLayer {
}
/**
- * Flushes the buffer to the network, non blocking
+ * Reads available bytes from socket channel to `netReadBuffer`.
+ * Visible for testing.
+ * @return number of bytes read
+ */
+ protected int readFromSocketChannel() throws IOException {
+ return socketChannel.read(netReadBuffer);
+ }
+
+ /**
+ * Flushes the buffer to the network, non blocking.
+ * Visible for testing.
* @param buf ByteBuffer
* @return boolean true if the buffer has been emptied out, false otherwise
* @throws IOException
*/
- private boolean flush(ByteBuffer buf) throws IOException {
+ protected boolean flush(ByteBuffer buf) throws IOException {
int remaining = buf.remaining();
if (remaining > 0) {
int written = socketChannel.write(buf);
@@ -217,101 +235,137 @@ public class SslTransportLayer implements TransportLayer {
* | unwrap() | Finished | FINISHED |
* +-------------+----------------------------------+-------------+
*
- * @throws IOException
+ * @throws IOException if read/write fails
+ * @throws SslAuthenticationException if handshake fails with an {@link SSLException}
*/
@Override
public void handshake() throws IOException {
+ // Reset state to support renegotiation. This can be removed if renegotiation support is removed.
+ if (state == State.READY)
+ state = State.HANDSHAKE;
+
+ int read = 0;
+ try {
+ // Read any available bytes before attempting any writes to ensure that handshake failures
+ // reported by the peer are processed even if writes fail (since peer closes connection
+ // if handshake fails)
+ if (key.isReadable())
+ read = readFromSocketChannel();
+
+ doHandshake();
+ } catch (SSLException e) {
+ handshakeFailure(e, true);
+ } catch (IOException e) {
+ maybeThrowSslAuthenticationException();
+
+ // this exception could be due to a write. If there is data available to unwrap,
+ // process the data so that any SSLExceptions are reported
+ if (handshakeStatus == HandshakeStatus.NEED_UNWRAP && netReadBuffer.position() > 0) {
+ try {
+ handshakeUnwrap(false);
+ } catch (SSLException e1) {
+ handshakeFailure(e1, false);
+ }
+ }
+ // If we get here, this is not a handshake failure, throw the original IOException
+ throw e;
+ }
+
+ // Read from socket failed, so throw any pending handshake exception or EOF exception.
+ if (read == -1) {
+ maybeThrowSslAuthenticationException();
+ throw new EOFException("EOF during handshake, handshake status is " + handshakeStatus);
+ }
+ }
+
+ private void doHandshake() throws IOException {
boolean read = key.isReadable();
boolean write = key.isWritable();
- handshakeComplete = false;
handshakeStatus = sslEngine.getHandshakeStatus();
if (!flush(netWriteBuffer)) {
key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
return;
}
- try {
- switch (handshakeStatus) {
- case NEED_TASK:
- log.trace("SSLHandshake NEED_TASK channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
- channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
- handshakeStatus = runDelegatedTasks();
+ // Throw any pending handshake exception since `netWriteBuffer` has been flushed
+ maybeThrowSslAuthenticationException();
+
+ switch (handshakeStatus) {
+ case NEED_TASK:
+ log.trace("SSLHandshake NEED_TASK channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+ channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+ handshakeStatus = runDelegatedTasks();
+ break;
+ case NEED_WRAP:
+ log.trace("SSLHandshake NEED_WRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+ channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+ handshakeResult = handshakeWrap(write);
+ if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
+ int currentNetWriteBufferSize = netWriteBufferSize();
+ netWriteBuffer.compact();
+ netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize);
+ netWriteBuffer.flip();
+ if (netWriteBuffer.limit() >= currentNetWriteBufferSize) {
+ throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.limit() +
+ ") >= network buffer size (" + currentNetWriteBufferSize + ")");
+ }
+ } else if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
+ throw new IllegalStateException("Should not have received BUFFER_UNDERFLOW during handshake WRAP.");
+ } else if (handshakeResult.getStatus() == Status.CLOSED) {
+ throw new EOFException();
+ }
+ log.trace("SSLHandshake NEED_WRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+ channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+ //if handshake status is not NEED_UNWRAP or unable to flush netWriteBuffer contents
+ //we will break here otherwise we can do need_unwrap in the same call.
+ if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || !flush(netWriteBuffer)) {
+ key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
break;
- case NEED_WRAP:
- log.trace("SSLHandshake NEED_WRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
- channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
- handshakeResult = handshakeWrap(write);
+ }
+ case NEED_UNWRAP:
+ log.trace("SSLHandshake NEED_UNWRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+ channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+ do {
+ handshakeResult = handshakeUnwrap(read);
if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
- int currentNetWriteBufferSize = netWriteBufferSize();
- netWriteBuffer.compact();
- netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize);
- netWriteBuffer.flip();
- if (netWriteBuffer.limit() >= currentNetWriteBufferSize) {
- throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.limit() +
- ") >= network buffer size (" + currentNetWriteBufferSize + ")");
+ int currentAppBufferSize = applicationBufferSize();
+ appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize);
+ if (appReadBuffer.position() > currentAppBufferSize) {
+ throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() +
+ ") > packet buffer size (" + currentAppBufferSize + ")");
}
- } else if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
- throw new IllegalStateException("Should not have received BUFFER_UNDERFLOW during handshake WRAP.");
- } else if (handshakeResult.getStatus() == Status.CLOSED) {
- throw new EOFException();
}
- log.trace("SSLHandshake NEED_WRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
- channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
- //if handshake status is not NEED_UNWRAP or unable to flush netWriteBuffer contents
- //we will break here otherwise we can do need_unwrap in the same call.
- if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || !flush(netWriteBuffer)) {
- key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
- break;
- }
- case NEED_UNWRAP:
- log.trace("SSLHandshake NEED_UNWRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
- channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
- do {
- handshakeResult = handshakeUnwrap(read);
- if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
- int currentAppBufferSize = applicationBufferSize();
- appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize);
- if (appReadBuffer.position() > currentAppBufferSize) {
- throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() +
- ") > packet buffer size (" + currentAppBufferSize + ")");
- }
- }
- } while (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW);
- if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
- int currentNetReadBufferSize = netReadBufferSize();
- netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize);
- if (netReadBuffer.position() >= currentNetReadBufferSize) {
- throw new IllegalStateException("Buffer underflow when there is available data");
- }
- } else if (handshakeResult.getStatus() == Status.CLOSED) {
- throw new EOFException("SSL handshake status CLOSED during handshake UNWRAP");
+ } while (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW);
+ if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
+ int currentNetReadBufferSize = netReadBufferSize();
+ netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize);
+ if (netReadBuffer.position() >= currentNetReadBufferSize) {
+ throw new IllegalStateException("Buffer underflow when there is available data");
}
- log.trace("SSLHandshake NEED_UNWRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
- channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
-
- //if handshakeStatus completed than fall-through to finished status.
- //after handshake is finished there is no data left to read/write in socketChannel.
- //so the selector won't invoke this channel if we don't go through the handshakeFinished here.
- if (handshakeStatus != HandshakeStatus.FINISHED) {
- if (handshakeStatus == HandshakeStatus.NEED_WRAP) {
- key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
- } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) {
- key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
- }
- break;
+ } else if (handshakeResult.getStatus() == Status.CLOSED) {
+ throw new EOFException("SSL handshake status CLOSED during handshake UNWRAP");
+ }
+ log.trace("SSLHandshake NEED_UNWRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+ channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+
+ //if handshakeStatus completed than fall-through to finished status.
+ //after handshake is finished there is no data left to read/write in socketChannel.
+ //so the selector won't invoke this channel if we don't go through the handshakeFinished here.
+ if (handshakeStatus != HandshakeStatus.FINISHED) {
+ if (handshakeStatus == HandshakeStatus.NEED_WRAP) {
+ key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
+ } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) {
+ key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
}
- case FINISHED:
- handshakeFinished();
- break;
- case NOT_HANDSHAKING:
- handshakeFinished();
break;
- default:
- throw new IllegalStateException(String.format("Unexpected status [%s]", handshakeStatus));
- }
-
- } catch (SSLException e) {
- handshakeFailure();
- throw e;
+ }
+ case FINISHED:
+ handshakeFinished();
+ break;
+ case NOT_HANDSHAKING:
+ handshakeFinished();
+ break;
+ default:
+ throw new IllegalStateException(String.format("Unexpected status [%s]", handshakeStatus));
}
}
@@ -346,12 +400,12 @@ public class SslTransportLayer implements TransportLayer {
// It can move from FINISHED status to NOT_HANDSHAKING after the handshake is completed.
// Hence we also need to check handshakeResult.getHandshakeStatus() if the handshake finished or not
if (handshakeResult.getHandshakeStatus() == HandshakeStatus.FINISHED) {
- //we are complete if we have delivered the last package
- handshakeComplete = !netWriteBuffer.hasRemaining();
+ //we are complete if we have delivered the last packet
//remove OP_WRITE if we are complete, otherwise we still have data to write
- if (!handshakeComplete)
+ if (netWriteBuffer.hasRemaining())
key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
else {
+ state = State.READY;
key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
SSLSession session = sslEngine.getSession();
log.debug("SSL handshake completed successfully with peerHost '{}' peerPort {} peerPrincipal '{}' cipherSuite '{}'",
@@ -400,10 +454,9 @@ public class SslTransportLayer implements TransportLayer {
private SSLEngineResult handshakeUnwrap(boolean doRead) throws IOException {
log.trace("SSLHandshake handshakeUnwrap {}", channelId);
SSLEngineResult result;
- if (doRead) {
- int read = socketChannel.read(netReadBuffer);
- if (read == -1) throw new EOFException("EOF during handshake.");
- }
+ int read = 0;
+ if (doRead)
+ read = readFromSocketChannel();
boolean cont;
do {
//prepare the buffer with the incoming data
@@ -420,6 +473,11 @@ public class SslTransportLayer implements TransportLayer {
log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status {}", handshakeStatus, result.getStatus());
} while (netReadBuffer.position() != 0 && cont);
+ // Throw EOF exception for failed read after processing already received data
+ // so that handshake failures are reported correctly
+ if (read == -1)
+ throw new EOFException("EOF during handshake, handshake status is " + handshakeStatus);
+
return result;
}
@@ -429,27 +487,27 @@ public class SslTransportLayer implements TransportLayer {
*
* @param dst The buffer into which bytes are to be transferred
* @return The number of bytes read, possible zero or -1 if the channel has reached end-of-stream
+ * and no more data is available
* @throws IOException if some other I/O error occurs
*/
@Override
public int read(ByteBuffer dst) throws IOException {
- if (closing) return -1;
- int read = 0;
- if (!handshakeComplete) return read;
+ if (state == State.CLOSING) return -1;
+ else if (state != State.READY) return 0;
//if we have unread decrypted data in appReadBuffer read that into dst buffer.
+ int read = 0;
if (appReadBuffer.position() > 0) {
read = readFromAppBuffer(dst);
}
+ int netread = 0;
if (dst.remaining() > 0) {
netReadBuffer = Utils.ensureCapacity(netReadBuffer, netReadBufferSize());
- if (netReadBuffer.remaining() > 0) {
- int netread = socketChannel.read(netReadBuffer);
- if (netread == 0 && netReadBuffer.position() == 0) return read;
- else if (netread < 0) throw new EOFException("EOF during read");
- }
- do {
+ if (netReadBuffer.remaining() > 0)
+ netread = readFromSocketChannel();
+
+ while (netReadBuffer.position() > 0) {
netReadBuffer.flip();
SSLEngineResult unwrapResult = sslEngine.unwrap(netReadBuffer, appReadBuffer);
netReadBuffer.compact();
@@ -493,8 +551,12 @@ public class SslTransportLayer implements TransportLayer {
else
break;
}
- } while (netReadBuffer.position() != 0);
+ }
}
+ // If data has been read and unwrapped, return the data even if end-of-stream, channel will be closed
+ // on a subsequent poll.
+ if (read == 0 && netread < 0)
+ throw new EOFException("EOF during read");
return read;
}
@@ -553,8 +615,8 @@ public class SslTransportLayer implements TransportLayer {
@Override
public int write(ByteBuffer src) throws IOException {
int written = 0;
- if (closing) throw new IllegalStateException("Channel is in closing state");
- if (!handshakeComplete) return written;
+ if (state == State.CLOSING) throw new IllegalStateException("Channel is in closing state");
+ if (state != State.READY) return written;
if (!flush(netWriteBuffer))
return written;
@@ -662,7 +724,7 @@ public class SslTransportLayer implements TransportLayer {
public void addInterestOps(int ops) {
if (!key.isValid())
throw new CancelledKeyException();
- else if (!handshakeComplete)
+ else if (state != State.READY)
throw new IllegalStateException("handshake is not completed");
key.interestOps(key.interestOps() | ops);
@@ -676,7 +738,7 @@ public class SslTransportLayer implements TransportLayer {
public void removeInterestOps(int ops) {
if (!key.isValid())
throw new CancelledKeyException();
- else if (!handshakeComplete)
+ else if (state != State.READY)
throw new IllegalStateException("handshake is not completed");
key.interestOps(key.interestOps() & ~ops);
@@ -723,7 +785,12 @@ public class SslTransportLayer implements TransportLayer {
return netReadBuffer;
}
- private void handshakeFailure() {
+ /**
+ * SSL exceptions are propagated as authentication failures so that clients can avoid
+ * retries and report the failure. If `flush` is true, exceptions are propagated after
+ * any pending outgoing bytes are flushed to ensure that the peer is notified of the failure.
+ */
+ private void handshakeFailure(SSLException sslException, boolean flush) throws IOException {
//Release all resources such as internal buffers that SSLEngine is managing
sslEngine.closeOutbound();
try {
@@ -731,6 +798,17 @@ public class SslTransportLayer implements TransportLayer {
} catch (SSLException e) {
log.debug("SSLEngine.closeInBound() raised an exception.", e);
}
+
+ state = State.HANDSHAKE_FAILED;
+ handshakeException = new SslAuthenticationException("SSL handshake failed", sslException);
+ if (!flush || flush(netWriteBuffer))
+ throw handshakeException;
+ }
+
+ // If handshake has already failed, throw the authentication exception.
+ private void maybeThrowSslAuthenticationException() {
+ if (handshakeException != null)
+ throw handshakeException;
}
@Override
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java
index be56ad5..23f866b 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java
@@ -31,6 +31,7 @@ import java.nio.channels.GatheringByteChannel;
import java.security.Principal;
+import org.apache.kafka.common.errors.AuthenticationException;
public interface TransportLayer extends ScatteringByteChannel, GatheringByteChannel {
@@ -61,11 +62,14 @@ public interface TransportLayer extends ScatteringByteChannel, GatheringByteChan
/**
- * Performs SSL handshake hence is a no-op for the non-secure
- * implementation
- * @throws IOException
+ * This a no-op for the non-secure PLAINTEXT implementation. For SSL, this performs
+ * SSL handshake. The SSL handshake includes client authentication if configured using
+ * {@link org.apache.kafka.common.config.SslConfigs#SSL_CLIENT_AUTH_CONFIG}.
+ * @throws AuthenticationException if handshake fails due to an
+ * {@link javax.net.ssl.SSLException}.
+ * @throws IOException if read or write fails with an I/O error.
*/
- void handshake() throws IOException;
+ void handshake() throws AuthenticationException, IOException;
/**
* Returns true if there are any pending writes
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
index 8207a5a..d9e4f0c 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
@@ -20,8 +20,8 @@ import org.apache.kafka.clients.CommonClientConfigs;
import org.apache.kafka.clients.NetworkClient;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.config.SaslConfigs;
-import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.IllegalSaslStateException;
+import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
import org.apache.kafka.common.network.Authenticator;
import org.apache.kafka.common.network.Mode;
@@ -104,9 +104,6 @@ public class SaslClientAuthenticator implements Authenticator {
private RequestHeader currentRequestHeader;
// Version of SaslAuthenticate request/responses
private short saslAuthenticateVersion;
- // Sasl authentication error which may be one of NONE, UNSUPPORTED_SASL_MECHANISM, ILLEGAL_SASL_STATE,
- // SASL_AUTHENTICATION_FAILED or NETWORK_EXCEPTION
- private Errors error;
public SaslClientAuthenticator(Map<String, ?> configs,
String node,
@@ -125,7 +122,6 @@ public class SaslClientAuthenticator implements Authenticator {
this.transportLayer = transportLayer;
this.configs = configs;
this.saslAuthenticateVersion = DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER;
- this.error = Errors.NONE;
try {
setSaslState(handshakeRequestEnable ? SaslState.SEND_APIVERSIONS_REQUEST : SaslState.INITIAL);
@@ -143,7 +139,7 @@ public class SaslClientAuthenticator implements Authenticator {
saslClient = createSaslClient();
} catch (Exception e) {
- throw new KafkaException("Failed to configure SaslClientAuthenticator", e);
+ throw new SaslAuthenticationException("Failed to configure SaslClientAuthenticator", e);
}
}
@@ -158,7 +154,7 @@ public class SaslClientAuthenticator implements Authenticator {
}
});
} catch (PrivilegedActionException e) {
- throw new KafkaException("Failed to create SaslClient with mechanism " + mechanism, e.getCause());
+ throw new SaslAuthenticationException("Failed to create SaslClient with mechanism " + mechanism, e.getCause());
}
}
@@ -236,11 +232,6 @@ public class SaslClientAuthenticator implements Authenticator {
}
}
- @Override
- public Errors error() {
- return error;
- }
-
private RequestHeader nextRequestHeader(ApiKeys apiKey, short version) {
String clientId = (String) configs.get(CommonClientConfigs.CLIENT_ID_CONFIG);
currentRequestHeader = new RequestHeader(apiKey, version, clientId, correlationId++);
@@ -345,8 +336,8 @@ public class SaslClientAuthenticator implements Authenticator {
} else {
SaslAuthenticateResponse response = (SaslAuthenticateResponse) receiveKafkaResponse();
if (response != null) {
- this.error = response.error();
- if (this.error != Errors.NONE) {
+ Errors error = response.error();
+ if (error != Errors.NONE) {
setSaslState(SaslState.FAILED);
String errMsg = response.errorMessage();
throw errMsg == null ? error.exception() : error.exception(errMsg);
@@ -360,7 +351,7 @@ public class SaslClientAuthenticator implements Authenticator {
private byte[] createSaslToken(final byte[] saslToken, boolean isInitial) throws SaslException {
if (saslToken == null)
- throw new SaslException("Error authenticating with the Kafka Broker: received a `null` saslToken.");
+ throw new IllegalSaslStateException("Error authenticating with the Kafka Broker: received a `null` saslToken.");
try {
if (isInitial && !saslClient.hasInitialResponse())
@@ -384,9 +375,9 @@ public class SaslClientAuthenticator implements Authenticator {
" Users must configure FQDN of kafka brokers when authenticating using SASL and" +
" `socketChannel.socket().getInetAddress().getHostName()` must match the hostname in `principal/hostname@realm`";
}
- error += " Kafka Client will go to AUTH_FAILED state.";
+ error += " Kafka Client will go to AUTHENTICATION_FAILED state.";
//Unwrap the SaslException inside `PrivilegedActionException`
- throw new SaslException(error, e.getCause());
+ throw new SaslAuthenticationException(error, e.getCause());
}
}
@@ -410,12 +401,12 @@ public class SaslClientAuthenticator implements Authenticator {
} catch (SchemaException | IllegalArgumentException e) {
LOG.debug("Invalid SASL mechanism response, server may be expecting only GSSAPI tokens");
setSaslState(SaslState.FAILED);
- throw new AuthenticationException("Invalid SASL mechanism response", e);
+ throw new IllegalSaslStateException("Invalid SASL mechanism response, server may be expecting a different protocol", e);
}
}
private void handleSaslHandshakeResponse(SaslHandshakeResponse response) {
- this.error = response.error();
+ Errors error = response.error();
if (error != Errors.NONE)
setSaslState(SaslState.FAILED);
switch (error) {
@@ -428,7 +419,7 @@ public class SaslClientAuthenticator implements Authenticator {
throw new IllegalSaslStateException(String.format("Unexpected handshake request with client mechanism %s, enabled mechanisms are %s",
mechanism, response.enabledMechanisms()));
default:
- throw new AuthenticationException(String.format("Unknown error code %s, client mechanism is %s, enabled mechanisms are %s",
+ throw new IllegalSaslStateException(String.format("Unknown error code %s, client mechanism is %s, enabled mechanisms are %s",
response.error(), mechanism, response.enabledMechanisms()));
}
}
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
index 6202131..fe57d27 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
@@ -123,8 +123,6 @@ public class SaslServerAuthenticator implements Authenticator {
private Send netOutBuffer;
// flag indicating if sasl tokens are sent as Kafka SaslAuthenticate request/responses
private boolean enableKafkaSaslAuthenticateHeaders;
- // authentication error if authentication failed
- private Errors error;
public SaslServerAuthenticator(Map<String, ?> configs,
String connectionId,
@@ -144,7 +142,6 @@ public class SaslServerAuthenticator implements Authenticator {
this.listenerName = listenerName;
this.securityProtocol = securityProtocol;
this.enableKafkaSaslAuthenticateHeaders = false;
- this.error = Errors.NONE;
this.transportLayer = transportLayer;
@@ -288,11 +285,6 @@ public class SaslServerAuthenticator implements Authenticator {
}
@Override
- public Errors error() {
- return error;
- }
-
- @Override
public boolean complete() {
return saslState == SaslState.COMPLETE;
}
@@ -366,13 +358,11 @@ public class SaslServerAuthenticator implements Authenticator {
KafkaPrincipal.ANONYMOUS, listenerName, securityProtocol);
RequestAndSize requestAndSize = requestContext.parseRequest(requestBuffer);
if (apiKey != ApiKeys.SASL_AUTHENTICATE) {
- this.error = Errors.ILLEGAL_SASL_STATE;
IllegalSaslStateException e = new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL authentication.");
sendKafkaResponse(requestContext, requestAndSize.request.getErrorResponse(e));
throw e;
}
if (!apiKey.isVersionSupported(version)) {
- this.error = Errors.UNSUPPORTED_VERSION;
// We cannot create an error response if the request version of SaslAuthenticate is not supported
// This should not normally occur since clients typically check supported versions using ApiVersionsRequest
throw new UnsupportedVersionException("Version " + version + " is not supported for apiKey " + apiKey);
@@ -385,8 +375,7 @@ public class SaslServerAuthenticator implements Authenticator {
ByteBuffer responseBuf = responseToken == null ? EMPTY_BUFFER : ByteBuffer.wrap(responseToken);
sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.NONE, null, responseBuf));
} catch (SaslException e) {
- this.error = Errors.SASL_AUTHENTICATION_FAILED;
- sendKafkaResponse(requestContext, new SaslAuthenticateResponse(this.error,
+ sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED,
"Authentication failed due to invalid credentials with SASL mechanism " + saslMechanism));
throw e;
}
@@ -462,8 +451,7 @@ public class SaslServerAuthenticator implements Authenticator {
return clientMechanism;
} else {
LOG.debug("SASL mechanism '{}' requested by client is not supported", clientMechanism);
- this.error = Errors.UNSUPPORTED_SASL_MECHANISM;
- sendKafkaResponse(context, new SaslHandshakeResponse(this.error, enabledMechanisms));
+ sendKafkaResponse(context, new SaslHandshakeResponse(Errors.UNSUPPORTED_SASL_MECHANISM, enabledMechanisms));
throw new UnsupportedSaslMechanismException("Unsupported SASL mechanism " + clientMechanism);
}
}
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
index e456d68..190fa3d 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
@@ -85,12 +85,14 @@ public class NioEchoServer extends Thread {
acceptorThread.start();
while (serverSocketChannel.isOpen()) {
selector.poll(1000);
- for (SocketChannel socketChannel : newChannels) {
- String id = id(socketChannel);
- selector.register(id, socketChannel);
- socketChannels.add(socketChannel);
+ synchronized (newChannels) {
+ for (SocketChannel socketChannel : newChannels) {
+ String id = id(socketChannel);
+ selector.register(id, socketChannel);
+ socketChannels.add(socketChannel);
+ }
+ newChannels.clear();
}
- newChannels.clear();
List<NetworkReceive> completedReceives = selector.completedReceives();
for (NetworkReceive rcv : completedReceives) {
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
index cffcc89..90c8cd5 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
@@ -48,8 +48,11 @@ import java.nio.channels.SocketChannel;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -159,7 +162,7 @@ public class SslTransportLayerTest {
InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
- NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
}
/**
@@ -187,17 +190,13 @@ public class SslTransportLayerTest {
sslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores);
// Create a server with endpoint validation enabled on the server SSL engine
- SslChannelBuilder serverChannelBuilder = new SslChannelBuilder(Mode.SERVER) {
+ SslChannelBuilder serverChannelBuilder = new TestSslChannelBuilder(Mode.SERVER) {
@Override
- protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException {
- SocketChannel socketChannel = (SocketChannel) key.channel();
- SSLEngine sslEngine = sslFactory.createSslEngine(host, socketChannel.socket().getPort());
+ protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, SSLEngine sslEngine) throws IOException {
SSLParameters sslParams = sslEngine.getSSLParameters();
sslParams.setEndpointIdentificationAlgorithm("HTTPS");
sslEngine.setSSLParameters(sslParams);
- TestSslTransportLayer transportLayer = new TestSslTransportLayer(id, key, sslEngine, BUFFER_SIZE, BUFFER_SIZE, BUFFER_SIZE);
- transportLayer.startHandshake();
- return transportLayer;
+ return super.newTransportLayer(id, key, sslEngine);
}
};
serverChannelBuilder.configure(sslServerConfigs);
@@ -211,7 +210,7 @@ public class SslTransportLayerTest {
NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
}
-
+
/**
* Tests that server certificate with invalid host name is not accepted by
* a client that validates server endpoint. Server certificate uses
@@ -230,9 +229,9 @@ public class SslTransportLayerTest {
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
- NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
}
-
+
/**
* Tests that server certificate with invalid IP address is accepted by
* a client that has disabled endpoint validation
@@ -252,7 +251,7 @@ public class SslTransportLayerTest {
NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
}
-
+
/**
* Tests that server accepts connections from clients with a trusted certificate
* when client authentication is required.
@@ -295,7 +294,7 @@ public class SslTransportLayerTest {
sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG);
createSelector(sslClientConfigs);
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
- NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
selector.close();
server.close();
@@ -308,7 +307,7 @@ public class SslTransportLayerTest {
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
}
-
+
/**
* Tests that server does not accept connections from clients with an untrusted certificate
* when client authentication is required.
@@ -323,9 +322,9 @@ public class SslTransportLayerTest {
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
- NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
}
-
+
/**
* Tests that server does not accept connections from clients which don't
* provide a certificate when client authentication is required.
@@ -335,7 +334,7 @@ public class SslTransportLayerTest {
String node = "0";
sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required");
server = createEchoServer(SecurityProtocol.SSL);
-
+
sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG);
@@ -343,9 +342,9 @@ public class SslTransportLayerTest {
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
- NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
}
-
+
/**
* Tests that server accepts connections from a client configured
* with an untrusted certificate if client authentication is disabled
@@ -362,7 +361,7 @@ public class SslTransportLayerTest {
NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
}
-
+
/**
* Tests that server accepts connections from a client that does not provide
* a certificate if client authentication is disabled
@@ -372,7 +371,7 @@ public class SslTransportLayerTest {
String node = "0";
sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "none");
server = createEchoServer(SecurityProtocol.SSL);
-
+
sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG);
@@ -382,7 +381,7 @@ public class SslTransportLayerTest {
NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
}
-
+
/**
* Tests that server accepts connections from a client configured
* with a valid certificate if client authentication is requested
@@ -398,7 +397,7 @@ public class SslTransportLayerTest {
NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
}
-
+
/**
* Tests that server accepts connections from a client that does not provide
* a certificate if client authentication is requested but not required
@@ -408,7 +407,7 @@ public class SslTransportLayerTest {
String node = "0";
sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "requested");
server = createEchoServer(SecurityProtocol.SSL);
-
+
sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG);
@@ -448,7 +447,7 @@ public class SslTransportLayerTest {
// Expected exception
}
}
-
+
/**
* Tests that channels cannot be created if keystore cannot be loaded
*/
@@ -481,7 +480,7 @@ public class SslTransportLayerTest {
NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
}
-
+
/**
* Tests that client connections cannot be created to a server
* if key password is invalid
@@ -495,9 +494,9 @@ public class SslTransportLayerTest {
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
- NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
}
-
+
/**
* Tests that connections cannot be made with unsupported TLS versions
*/
@@ -506,15 +505,15 @@ public class SslTransportLayerTest {
String node = "0";
sslServerConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.2"));
server = createEchoServer(SecurityProtocol.SSL);
-
+
sslClientConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.1"));
createSelector(sslClientConfigs);
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
- NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
}
-
+
/**
* Tests that connections cannot be made with unsupported TLS cipher suites
*/
@@ -524,13 +523,13 @@ public class SslTransportLayerTest {
String[] cipherSuites = SSLContext.getDefault().getDefaultSSLParameters().getCipherSuites();
sslServerConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(cipherSuites[0]));
server = createEchoServer(SecurityProtocol.SSL);
-
+
sslClientConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(cipherSuites[1]));
createSelector(sslClientConfigs);
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
- NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
}
/**
@@ -546,7 +545,7 @@ public class SslTransportLayerTest {
NetworkTestUtils.checkClientConnection(selector, node, 64000, 10);
}
-
+
/**
* Tests handling of BUFFER_OVERFLOW during wrap when network write buffer is smaller than SSL session packet buffer size.
*/
@@ -602,14 +601,98 @@ public class SslTransportLayerTest {
}
assertTrue("Send time not recorded", channel.getAndResetNetworkThreadTimeNanos() > 0);
assertEquals("Time not reset", 0, channel.getAndResetNetworkThreadTimeNanos());
+ assertFalse("Unexpected bytes buffered", channel.hasBytesBuffered());
+ assertEquals(0, selector.completedReceives().size());
selector.unmute(node);
while (selector.completedReceives().isEmpty()) {
selector.poll(100L);
+ assertEquals(0, selector.numStagedReceives(channel));
}
assertTrue("Receive time not recorded", channel.getAndResetNetworkThreadTimeNanos() > 0);
}
+ /**
+ * Tests that IOExceptions from read during SSL handshake are not treated as authentication failures.
+ */
+ @Test
+ public void testIOExceptionsDuringHandshakeRead() throws Exception {
+ testIOExceptionsDuringHandshake(true, false);
+ }
+
+ /**
+ * Tests that IOExceptions from write during SSL handshake are not treated as authentication failures.
+ */
+ @Test
+ public void testIOExceptionsDuringHandshakeWrite() throws Exception {
+ testIOExceptionsDuringHandshake(false, true);
+ }
+
+ private void testIOExceptionsDuringHandshake(boolean failRead, boolean failWrite) throws Exception {
+ server = createEchoServer(SecurityProtocol.SSL);
+ TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
+ boolean done = false;
+ for (int i = 1; i <= 100; i++) {
+ int readFailureIndex = failRead ? i : Integer.MAX_VALUE;
+ int flushFailureIndex = failWrite ? i : Integer.MAX_VALUE;
+ String node = String.valueOf(i);
+
+ channelBuilder.readFailureIndex = readFailureIndex;
+ channelBuilder.flushFailureIndex = flushFailureIndex;
+ channelBuilder.configure(sslClientConfigs);
+ this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext());
+
+ InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
+ selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
+ for (int j = 0; j < 30; j++) {
+ selector.poll(1000L);
+ KafkaChannel channel = selector.channel(node);
+ if (channel != null && channel.ready()) {
+ done = true;
+ break;
+ }
+ if (selector.disconnected().containsKey(node)) {
+ assertEquals(ChannelState.State.AUTHENTICATE, selector.disconnected().get(node).state());
+ break;
+ }
+ }
+ KafkaChannel channel = selector.channel(node);
+ if (channel != null)
+ assertTrue("Channel not ready or disconnected:" + channel.state().state(), channel.ready());
+ }
+ assertTrue("Too many invocations of read/write during SslTransportLayer.handshake()", done);
+ }
+
+ /**
+ * Tests that handshake failures are propagated only after writes complete, even when
+ * there are delays in writes to ensure that clients see an authentication exception
+ * rather than a connection failure.
+ */
+ @Test
+ public void testPeerNotifiedOfHandshakeFailure() throws Exception {
+ sslServerConfigs = serverCertStores.getUntrustingConfig();
+ sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required");
+
+ // Test without delay and a couple of delay counts to ensure delay applies to handshake failure
+ for (int i = 0; i < 3; i++) {
+ String node = "0";
+ TestSslChannelBuilder serverChannelBuilder = new TestSslChannelBuilder(Mode.SERVER);
+ serverChannelBuilder.configure(sslServerConfigs);
+ serverChannelBuilder.flushDelayCount = i;
+ server = new NioEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL),
+ SecurityProtocol.SSL, new TestSecurityConfig(sslServerConfigs),
+ "localhost", serverChannelBuilder, null);
+ server.start();
+ createSelector(sslClientConfigs);
+ InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
+ selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
+
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
+ server.close();
+ selector.close();
+ }
+ }
+
@Test
public void testCloseSsl() throws Exception {
testClose(SecurityProtocol.SSL, new SslChannelBuilder(Mode.CLIENT));
@@ -654,24 +737,13 @@ public class SslTransportLayerTest {
private void createSelector(Map<String, Object> sslClientConfigs) {
createSelector(sslClientConfigs, null, null, null);
- }
+ }
private void createSelector(Map<String, Object> sslClientConfigs, final Integer netReadBufSize,
final Integer netWriteBufSize, final Integer appBufSize) {
-
- this.channelBuilder = new SslChannelBuilder(Mode.CLIENT) {
-
- @Override
- protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException {
- SocketChannel socketChannel = (SocketChannel) key.channel();
- SSLEngine sslEngine = sslFactory.createSslEngine(host, socketChannel.socket().getPort());
- TestSslTransportLayer transportLayer = new TestSslTransportLayer(id, key, sslEngine, netReadBufSize, netWriteBufSize, appBufSize);
- transportLayer.startHandshake();
- return transportLayer;
- }
-
-
- };
+ TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
+ channelBuilder.configureBufferSizes(netReadBufSize, netWriteBufSize, appBufSize);
+ this.channelBuilder = channelBuilder;
this.channelBuilder.configure(sslClientConfigs);
this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext());
}
@@ -683,47 +755,111 @@ public class SslTransportLayerTest {
private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception {
return createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol);
}
-
- /**
- * SSLTransportLayer with overrides for packet and application buffer size to test buffer resize
- * code path. The overridden buffer size starts with a small value and increases in size when the buffer
- * size is retrieved to handle overflow/underflow, until the actual session buffer size is reached.
- */
- private static class TestSslTransportLayer extends SslTransportLayer {
-
- private final ResizeableBufferSize netReadBufSize;
- private final ResizeableBufferSize netWriteBufSize;
- private final ResizeableBufferSize appBufSize;
-
- public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine,
- Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) throws IOException {
- super(channelId, key, sslEngine, false);
- this.netReadBufSize = new ResizeableBufferSize(netReadBufSize);
- this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSize);
- this.appBufSize = new ResizeableBufferSize(appBufSize);
+
+ private static class TestSslChannelBuilder extends SslChannelBuilder {
+
+ private Integer netReadBufSizeOverride;
+ private Integer netWriteBufSizeOverride;
+ private Integer appBufSizeOverride;
+ long readFailureIndex = Long.MAX_VALUE;
+ long flushFailureIndex = Long.MAX_VALUE;
+ int flushDelayCount = 0;
+
+ public TestSslChannelBuilder(Mode mode) {
+ super(mode);
}
-
- @Override
- protected int netReadBufferSize() {
- ByteBuffer netReadBuffer = netReadBuffer();
- // netReadBufferSize() is invoked in SSLTransportLayer.read() prior to the read
- // operation. To avoid the read buffer being expanded too early, increase buffer size
- // only when read buffer is full. This ensures that BUFFER_UNDERFLOW is always
- // triggered in testNetReadBufferResize().
- boolean updateBufSize = netReadBuffer != null && !netReadBuffer().hasRemaining();
- return netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize);
+
+ public void configureBufferSizes(Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) {
+ this.netReadBufSizeOverride = netReadBufSize;
+ this.netWriteBufSizeOverride = netWriteBufSize;
+ this.appBufSizeOverride = appBufSize;
}
-
+
@Override
- protected int netWriteBufferSize() {
- return netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true);
+ protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException {
+ // The SSLEngine is created for the remote peer's host and port, as in the production builder
+ int peerPort = ((SocketChannel) key.channel()).socket().getPort();
+ // Route creation through the overridable factory hook so the test transport layer is used
+ TestSslTransportLayer transportLayer = newTransportLayer(id, key, sslFactory.createSslEngine(host, peerPort));
+ transportLayer.startHandshake();
+ return transportLayer;
 }
- @Override
- protected int applicationBufferSize() {
- return appBufSize.updateAndGet(super.applicationBufferSize(), true);
+ /**
+ * Factory hook used by buildTransportLayer() to create the transport layer for a new channel.
+ * NOTE(review): presumably protected so individual tests can substitute a TestSslTransportLayer
+ * subclass — confirm against callers.
+ *
+ * @param id channel id
+ * @param key selection key of the channel's socket
+ * @param sslEngine engine created for the channel's peer
+ * @return a new, not yet handshaking, TestSslTransportLayer
+ */
+ protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, SSLEngine sslEngine) throws IOException {
+ return new TestSslTransportLayer(id, key, sslEngine);
+ }
+
+ /**
+ * {@link SslTransportLayer} with overrides for testing, including:
+ * <ul>
+ * <li>Overrides for packet and application buffer sizes to test the buffer resize code path.
+ * The overridden buffer size starts with a small value and increases in size when the buffer size
+ * is retrieved to handle overflow/underflow, until the actual session buffer size is reached.</li>
+ * <li>IOException injection for socket reads and flushes, to test exception handling during
+ * handshakes. Failures are only injected while the transport layer is not yet {@code ready()}.</li>
+ * <li>Delayed flushes to test handshake failure notifications to the peer.</li>
+ * </ul>
+ * Non-static so that it can read the enclosing builder's test settings at construction time.
+ */
+ class TestSslTransportLayer extends SslTransportLayer {
+
+ private final ResizeableBufferSize netReadBufSize;
+ private final ResizeableBufferSize netWriteBufSize;
+ private final ResizeableBufferSize appBufSize;
+ // Counts down to the socket read that throws the injected IOException
+ private final AtomicLong numReadsRemaining;
+ // Counts down to the flush that throws the injected IOException
+ private final AtomicLong numFlushesRemaining;
+ // Flush attempts still to be skipped (returning false) before the next real flush
+ private final AtomicInteger numDelayedFlushesRemaining;
+
+ public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine) throws IOException {
+ super(channelId, key, sslEngine, false);
+ this.netReadBufSize = new ResizeableBufferSize(netReadBufSizeOverride);
+ this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSizeOverride);
+ this.appBufSize = new ResizeableBufferSize(appBufSizeOverride);
+ this.numReadsRemaining = new AtomicLong(readFailureIndex);
+ this.numFlushesRemaining = new AtomicLong(flushFailureIndex);
+ this.numDelayedFlushesRemaining = new AtomicInteger(flushDelayCount);
+ }
+
+ @Override
+ protected int netReadBufferSize() {
+ ByteBuffer netReadBuffer = netReadBuffer();
+ // netReadBufferSize() is invoked in SslTransportLayer.read() prior to the read
+ // operation. To avoid the read buffer being expanded too early, increase buffer size
+ // only when read buffer is full. This ensures that BUFFER_UNDERFLOW is always
+ // triggered in testNetReadBufferResize().
+ boolean updateBufSize = netReadBuffer != null && !netReadBuffer.hasRemaining();
+ return netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize);
+ }
+
+ @Override
+ protected int netWriteBufferSize() {
+ return netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true);
+ }
+
+ @Override
+ protected int applicationBufferSize() {
+ return appBufSize.updateAndGet(super.applicationBufferSize(), true);
+ }
+
+ @Override
+ protected int readFromSocketChannel() throws IOException {
+ // The counter is decremented on every read; the readFailureIndex-th read fails if still handshaking
+ if (numReadsRemaining.decrementAndGet() == 0 && !ready())
+ throw new IOException("Test exception during read");
+ return super.readFromSocketChannel();
+ }
+
+ @Override
+ protected boolean flush(ByteBuffer buf) throws IOException {
+ // The counter is decremented on every flush; the flushFailureIndex-th flush fails if still handshaking
+ if (numFlushesRemaining.decrementAndGet() == 0 && !ready())
+ throw new IOException("Test exception during write");
+ // Report "not flushed" without writing for flushDelayCount consecutive attempts
+ if (numDelayedFlushesRemaining.getAndDecrement() != 0)
+ return false;
+ resetDelayedFlush();
+ return super.flush(buf);
+ }
+
+ // Re-arms the delay so the next flushDelayCount flush attempts are skipped again
+ private void resetDelayedFlush() {
+ numDelayedFlushesRemaining.set(flushDelayCount);
+ }
 }
-
+
private static class ResizeableBufferSize {
private Integer bufSizeOverride;
ResizeableBufferSize(Integer bufSizeOverride) {
@@ -740,5 +876,4 @@ public class SslTransportLayerTest {
}
}
}
-
}