You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by sh...@apache.org on 2022/06/10 13:33:46 UTC
[kafka] branch trunk updated: [KAFKA-13848] Clients remain connected after SASL re-authentication f… (#12179)
This is an automated email from the ASF dual-hosted git repository.
showuon pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push:
new 3d5b41e05f [KAFKA-13848] Clients remain connected after SASL re-authentication f… (#12179)
3d5b41e05f is described below
commit 3d5b41e05f0011dcf607f7d9cb269e8bb4558c57
Author: András Csáki <ac...@users.noreply.github.com>
AuthorDate: Fri Jun 10 15:33:33 2022 +0200
[KAFKA-13848] Clients remain connected after SASL re-authentication f… (#12179)
Clients remain connected and able to produce or consume despite an expired OAUTHBEARER token.
Root cause seems to be SaslServerAuthenticator#calcCompletionTimesAndReturnSessionLifetimeMs failing to set ReauthInfo#sessionExpirationTimeNanos when tokens have already expired (i.e. when the session lifetime goes negative). This in turn causes KafkaChannel#serverAuthenticationSessionExpired to return false, and finally SocketServer does not close the channel.
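The issue is observed with OAUTHBEARER but seems to have a wider impact on SASL re-authentication. The same gap occurred when connections.max.reauth.ms was 0: the session lifetime was clamped to 0 by Math.min, so no session expiration was recorded. The fix therefore treats a non-positive connections.max.reauth.ms as unset and always records the session expiration time once a lifetime has been computed.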
Reviewers: Luke Chen <sh...@gmail.com>, Tom Bentley <tb...@redhat.com>, Sam Barker <sb...@redhat.com>
---
build.gradle | 2 +-
.../authenticator/SaslServerAuthenticator.java | 32 ++-
.../apache/kafka/common/network/SelectorTest.java | 8 +-
.../authenticator/SaslServerAuthenticatorTest.java | 269 +++++++++++++++++++--
.../ExpiringCredentialRefreshingLoginTest.java | 7 +-
5 files changed, 266 insertions(+), 52 deletions(-)
diff --git a/build.gradle b/build.gradle
index 064e397c15..e34010c166 100644
--- a/build.gradle
+++ b/build.gradle
@@ -1245,7 +1245,7 @@ project(':clients') {
testImplementation libs.bcpkix
testImplementation libs.junitJupiter
- testImplementation libs.mockitoCore
+ testImplementation libs.mockitoInline
testRuntimeOnly libs.slf4jlog4j
testRuntimeOnly libs.jacksonDatabind
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
index 6e35ee7a90..019723b6b4 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
@@ -673,30 +673,26 @@ public class SaslServerAuthenticator implements Authenticator {
Long credentialExpirationMs = (Long) saslServer
.getNegotiatedProperty(SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY);
Long connectionsMaxReauthMs = connectionsMaxReauthMsByMechanism.get(saslMechanism);
- if (credentialExpirationMs != null || connectionsMaxReauthMs != null) {
+ boolean maxReauthSet = connectionsMaxReauthMs != null && connectionsMaxReauthMs > 0;
+
+ if (credentialExpirationMs != null || maxReauthSet) {
if (credentialExpirationMs == null)
retvalSessionLifetimeMs = zeroIfNegative(connectionsMaxReauthMs);
- else if (connectionsMaxReauthMs == null)
+ else if (!maxReauthSet)
retvalSessionLifetimeMs = zeroIfNegative(credentialExpirationMs - authenticationEndMs);
else
- retvalSessionLifetimeMs = zeroIfNegative(
- Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs));
- if (retvalSessionLifetimeMs > 0L)
- sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
+ retvalSessionLifetimeMs = zeroIfNegative(Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs));
+
+ sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
}
+
if (credentialExpirationMs != null) {
- if (sessionExpirationTimeNanos != null)
- LOG.debug(
- "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); session expiration = {} ({} ms), sending {} ms to client",
- connectionsMaxReauthMs, new Date(credentialExpirationMs),
- credentialExpirationMs - authenticationEndMs,
- new Date(authenticationEndMs + retvalSessionLifetimeMs), retvalSessionLifetimeMs,
- retvalSessionLifetimeMs);
- else
- LOG.debug(
- "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); no session expiration, sending 0 ms to client",
- connectionsMaxReauthMs, new Date(credentialExpirationMs),
- credentialExpirationMs - authenticationEndMs);
+ LOG.debug(
+ "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); session expiration = {} ({} ms), sending {} ms to client",
+ connectionsMaxReauthMs, new Date(credentialExpirationMs),
+ credentialExpirationMs - authenticationEndMs,
+ new Date(authenticationEndMs + retvalSessionLifetimeMs), retvalSessionLifetimeMs,
+ retvalSessionLifetimeMs);
} else {
if (sessionExpirationTimeNanos != null)
LOG.debug(
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
index 59767f9dc2..09f14531de 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
@@ -776,14 +776,13 @@ public class SelectorTest {
when(kafkaChannel.selectionKey()).thenReturn(selectionKey);
when(selectionKey.channel()).thenReturn(SocketChannel.open());
when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_CONNECT);
+ when(selectionKey.attachment()).thenReturn(kafkaChannel);
- selectionKey.attach(kafkaChannel);
Set<SelectionKey> selectionKeys = Utils.mkSet(selectionKey);
selector.pollSelectionKeys(selectionKeys, false, System.nanoTime());
assertFalse(selector.connected().contains(kafkaChannel.id()));
assertTrue(selector.disconnected().containsKey(kafkaChannel.id()));
- assertNull(selectionKey.attachment());
verify(kafkaChannel, atLeastOnce()).ready();
verify(kafkaChannel).disconnect();
@@ -971,8 +970,11 @@ public class SelectorTest {
SelectionKey selectionKey = mock(SelectionKey.class);
when(channel.selectionKey()).thenReturn(selectionKey);
when(selectionKey.isValid()).thenReturn(true);
+ when(selectionKey.isReadable()).thenReturn(true);
when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_READ);
- selectionKey.attach(channel);
+ when(selectionKey.attachment())
+ .thenReturn(channel)
+ .thenReturn(null);
selectionKeys.add(selectionKey);
NetworkReceive receive = mock(NetworkReceive.class);
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
index af0fedd4f5..50696ecf05 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
@@ -16,10 +16,12 @@
*/
package org.apache.kafka.common.security.authenticator;
-import java.net.InetAddress;
import org.apache.kafka.common.config.internals.BrokerSecurityConfigs;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.apache.kafka.common.message.ApiMessageType;
+import org.apache.kafka.common.message.SaslAuthenticateRequestData;
+import org.apache.kafka.common.message.SaslHandshakeRequestData;
+import org.apache.kafka.common.network.ChannelBuilders;
import org.apache.kafka.common.network.ChannelMetadataRegistry;
import org.apache.kafka.common.network.ClientInformation;
import org.apache.kafka.common.network.DefaultChannelMetadataRegistry;
@@ -27,31 +29,55 @@ import org.apache.kafka.common.network.InvalidReceiveException;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.network.TransportLayer;
import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.requests.AbstractRequest;
import org.apache.kafka.common.requests.ApiVersionsRequest;
import org.apache.kafka.common.requests.ApiVersionsResponse;
+import org.apache.kafka.common.requests.RequestHeader;
import org.apache.kafka.common.requests.RequestTestUtils;
+import org.apache.kafka.common.requests.ResponseHeader;
+import org.apache.kafka.common.requests.SaslAuthenticateRequest;
+import org.apache.kafka.common.requests.SaslAuthenticateResponse;
+import org.apache.kafka.common.requests.SaslHandshakeRequest;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.KafkaPrincipal;
+import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
import org.apache.kafka.common.security.auth.SecurityProtocol;
-import org.apache.kafka.common.requests.RequestHeader;
+import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
+import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.plain.PlainLoginModule;
+import org.apache.kafka.common.security.ssl.SslPrincipalMapper;
import org.apache.kafka.common.utils.AppInfoParser;
+import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
+import org.junit.jupiter.api.Test;
+import org.mockito.Answers;
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
import javax.security.auth.Subject;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
import java.io.IOException;
+import java.net.InetAddress;
+import java.nio.Buffer;
import java.nio.ByteBuffer;
+import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
-
-import org.junit.jupiter.api.Test;
-import org.mockito.Answers;
+import java.util.stream.Collectors;
import static org.apache.kafka.common.security.scram.internals.ScramMechanism.SCRAM_SHA_256;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyMap;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -59,13 +85,15 @@ import static org.mockito.Mockito.when;
public class SaslServerAuthenticatorTest {
+ private final String clientId = "clientId";
+
@Test
public void testOversizeRequest() throws IOException {
TransportLayer transportLayer = mock(TransportLayer.class);
Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
Collections.singletonList(SCRAM_SHA_256.mechanismName()));
SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer,
- SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry());
+ SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry());
when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> {
invocation.<ByteBuffer>getArgument(0).putInt(SaslServerAuthenticator.MAX_RECEIVE_SIZE + 1);
@@ -81,9 +109,9 @@ public class SaslServerAuthenticatorTest {
Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
Collections.singletonList(SCRAM_SHA_256.mechanismName()));
SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer,
- SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry());
+ SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry());
- RequestHeader header = new RequestHeader(ApiKeys.METADATA, (short) 0, "clientId", 13243);
+ RequestHeader header = new RequestHeader(ApiKeys.METADATA, (short) 0, clientId, 13243);
ByteBuffer headerBuffer = RequestTestUtils.serializeRequestHeader(header);
when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> {
@@ -108,42 +136,223 @@ public class SaslServerAuthenticatorTest {
@Test
public void testOldestApiVersionsRequest() throws IOException {
testApiVersionsRequest(ApiKeys.API_VERSIONS.oldestVersion(),
- ClientInformation.UNKNOWN_NAME_OR_VERSION, ClientInformation.UNKNOWN_NAME_OR_VERSION);
+ ClientInformation.UNKNOWN_NAME_OR_VERSION, ClientInformation.UNKNOWN_NAME_OR_VERSION);
}
@Test
public void testLatestApiVersionsRequest() throws IOException {
testApiVersionsRequest(ApiKeys.API_VERSIONS.latestVersion(),
- "apache-kafka-java", AppInfoParser.getVersion());
+ "apache-kafka-java", AppInfoParser.getVersion());
}
- private void testApiVersionsRequest(short version, String expectedSoftwareName,
- String expectedSoftwareVersion) throws IOException {
- TransportLayer transportLayer = mock(TransportLayer.class, Answers.RETURNS_DEEP_STUBS);
+ @Test
+ public void testSessionExpiresAtTokenExpiryDespiteNoReauthIsSet() throws IOException {
+ String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
+ Duration tokenExpirationDuration = Duration.ofSeconds(1);
+ SaslServer saslServer = mock(SaslServer.class);
+
+ MockTime time = new MockTime();
+ try (
+ MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, tokenExpirationDuration);
+ MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
+ TransportLayer transportLayer = mockTransportLayer()
+ ) {
+
+ SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, 0L);
+
+ mockRequest(saslHandshakeRequest(mechanism), transportLayer);
+ authenticator.authenticate();
+
+ when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
+ mockRequest(saslAuthenticateRequest(), transportLayer);
+ authenticator.authenticate();
+
+ long atTokenExpiryNanos = time.nanoseconds() + tokenExpirationDuration.toNanos();
+ assertEquals(atTokenExpiryNanos, authenticator.serverSessionExpirationTimeNanos());
+
+ ByteBuffer secondResponseSent = getResponses(transportLayer).get(1);
+ consumeSizeAndHeader(secondResponseSent);
+ SaslAuthenticateResponse response = SaslAuthenticateResponse.parse(secondResponseSent, (short) 2);
+ assertEquals(tokenExpirationDuration.toMillis(), response.sessionLifetimeMs());
+ }
+ }
+
+ @Test
+ public void testSessionExpiresAtMaxReauthTime() throws IOException {
+ String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
+ SaslServer saslServer = mock(SaslServer.class);
+ MockTime time = new MockTime(0, 1, 1000);
+ long maxReauthMs = 100L;
+ Duration tokenExpiryGreaterThanMaxReauth = Duration.ofMillis(maxReauthMs).multipliedBy(10);
+
+ try (
+ MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, tokenExpiryGreaterThanMaxReauth);
+ MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
+ TransportLayer transportLayer = mockTransportLayer()
+ ) {
+
+ SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, maxReauthMs);
+
+ mockRequest(saslHandshakeRequest(mechanism), transportLayer);
+ authenticator.authenticate();
+
+ when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
+ mockRequest(saslAuthenticateRequest(), transportLayer);
+ authenticator.authenticate();
+
+ long atMaxReauthNanos = time.nanoseconds() + Duration.ofMillis(maxReauthMs).toNanos();
+ assertEquals(atMaxReauthNanos, authenticator.serverSessionExpirationTimeNanos());
+
+ ByteBuffer secondResponseSent = getResponses(transportLayer).get(1);
+ consumeSizeAndHeader(secondResponseSent);
+ SaslAuthenticateResponse response = SaslAuthenticateResponse.parse(secondResponseSent, (short) 2);
+ assertEquals(maxReauthMs, response.sessionLifetimeMs());
+ }
+ }
+
+ @Test
+ public void testSessionExpiresAtTokenExpiry() throws IOException {
+ String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
+ SaslServer saslServer = mock(SaslServer.class);
+ MockTime time = new MockTime(0, 1, 1000);
+ Duration tokenExpiryShorterThanMaxReauth = Duration.ofSeconds(2);
+ long maxReauthMs = tokenExpiryShorterThanMaxReauth.multipliedBy(2).toMillis();
+
+ try (
+ MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, tokenExpiryShorterThanMaxReauth);
+ MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
+ TransportLayer transportLayer = mockTransportLayer()
+ ) {
+
+ SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, maxReauthMs);
+
+ mockRequest(saslHandshakeRequest(mechanism), transportLayer);
+ authenticator.authenticate();
+
+ when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
+ mockRequest(saslAuthenticateRequest(), transportLayer);
+ authenticator.authenticate();
+
+ long atTokenExpiryNanos = time.nanoseconds() + tokenExpiryShorterThanMaxReauth.toNanos();
+ assertEquals(atTokenExpiryNanos, authenticator.serverSessionExpirationTimeNanos());
+
+ ByteBuffer secondResponseSent = getResponses(transportLayer).get(1);
+ consumeSizeAndHeader(secondResponseSent);
+ SaslAuthenticateResponse response = SaslAuthenticateResponse.parse(secondResponseSent, (short) 2);
+ assertEquals(tokenExpiryShorterThanMaxReauth.toMillis(), response.sessionLifetimeMs());
+ }
+ }
+
+ private SaslServerAuthenticator getSaslServerAuthenticatorForOAuth(String mechanism, TransportLayer transportLayer, Time time, Long maxReauth) {
Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
- Collections.singletonList(SCRAM_SHA_256.mechanismName()));
+ Collections.singletonList(mechanism));
ChannelMetadataRegistry metadataRegistry = new DefaultChannelMetadataRegistry();
- SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer,
- SCRAM_SHA_256.mechanismName(), metadataRegistry);
- RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS, version, "clientId", 0);
+ return setupAuthenticator(configs, transportLayer, mechanism, metadataRegistry, time, maxReauth);
+ }
+
+ private MockedStatic<?> mockSaslServer(SaslServer saslServer, String mechanism, Time time, Duration tokenExpirationDuration) throws SaslException {
+ when(saslServer.getMechanismName()).thenReturn(mechanism);
+ when(saslServer.evaluateResponse(any())).thenReturn(new byte[]{});
+ long millisToExpiration = tokenExpirationDuration.toMillis();
+ when(saslServer.getNegotiatedProperty(eq(SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY)))
+ .thenReturn(time.milliseconds() + millisToExpiration);
+ return Mockito.mockStatic(Sasl.class, (Answer<SaslServer>) invocation -> saslServer);
+ }
+
+ private MockedStatic<?> mockKafkaPrincipal(String principalType, String name) {
+ KafkaPrincipalBuilder kafkaPrincipalBuilder = mock(KafkaPrincipalBuilder.class);
+ when(kafkaPrincipalBuilder.build(any())).thenReturn(new KafkaPrincipal(principalType, name));
+ MockedStatic<ChannelBuilders> channelBuilders = Mockito.mockStatic(ChannelBuilders.class, Answers.RETURNS_MOCKS);
+ channelBuilders.when(() ->
+ ChannelBuilders.createPrincipalBuilder(anyMap(), any(KerberosShortNamer.class), any(SslPrincipalMapper.class))
+ ).thenReturn(kafkaPrincipalBuilder);
+ return channelBuilders;
+ }
+
+ private void consumeSizeAndHeader(ByteBuffer responseBuffer) {
+ responseBuffer.getInt();
+ ResponseHeader.parse(responseBuffer, (short) 1);
+ }
+
+ private List<ByteBuffer> getResponses(TransportLayer transportLayer) throws IOException {
+ ArgumentCaptor<ByteBuffer[]> buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class);
+ verify(transportLayer, times(2)).write(buffersCaptor.capture());
+ return buffersCaptor.getAllValues().stream()
+ .map(this::concatBuffers)
+ .collect(Collectors.toList());
+ }
+
+ private ByteBuffer concatBuffers(ByteBuffer[] buffers) {
+ int combinedCapacity = 0;
+ for (ByteBuffer buffer : buffers) {
+ combinedCapacity += buffer.capacity();
+ }
+ if (combinedCapacity > 0) {
+ ByteBuffer concat = ByteBuffer.allocate(combinedCapacity);
+ for (ByteBuffer buffer : buffers) {
+ concat.put(buffer);
+ }
+ return safeFlip(concat);
+ } else {
+ return ByteBuffer.allocate(0);
+ }
+ }
+
+ private ByteBuffer safeFlip(ByteBuffer buffer) {
+ return (ByteBuffer) ((Buffer) buffer).flip();
+ }
+
+ private SaslAuthenticateRequest saslAuthenticateRequest() {
+ SaslAuthenticateRequestData authenticateRequestData = new SaslAuthenticateRequestData();
+ return new SaslAuthenticateRequest.Builder(authenticateRequestData).build(ApiKeys.SASL_AUTHENTICATE.latestVersion());
+ }
+
+ private SaslHandshakeRequest saslHandshakeRequest(String mechanism) {
+ SaslHandshakeRequestData handshakeRequestData = new SaslHandshakeRequestData();
+ handshakeRequestData.setMechanism(mechanism);
+ return new SaslHandshakeRequest.Builder(handshakeRequestData).build(ApiKeys.SASL_HANDSHAKE.latestVersion());
+ }
+
+ private TransportLayer mockTransportLayer() throws IOException {
+ TransportLayer transportLayer = mock(TransportLayer.class, Answers.RETURNS_DEEP_STUBS);
+ when(transportLayer.socketChannel().socket().getInetAddress()).thenReturn(InetAddress.getLoopbackAddress());
+ when(transportLayer.write(any(ByteBuffer[].class))).thenReturn(Long.MAX_VALUE);
+ return transportLayer;
+ }
+
+ private void mockRequest(AbstractRequest request, TransportLayer transportLayer) throws IOException {
+ mockRequest(new RequestHeader(request.apiKey(), request.apiKey().latestVersion(), clientId, 0), request, transportLayer);
+ }
+
+ private void mockRequest(RequestHeader header, AbstractRequest request, TransportLayer transportLayer) throws IOException {
ByteBuffer headerBuffer = RequestTestUtils.serializeRequestHeader(header);
- ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(version);
ByteBuffer requestBuffer = request.serialize();
requestBuffer.rewind();
- when(transportLayer.socketChannel().socket().getInetAddress()).thenReturn(InetAddress.getLoopbackAddress());
-
when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> {
invocation.<ByteBuffer>getArgument(0).putInt(headerBuffer.remaining() + requestBuffer.remaining());
return 4;
}).then(invocation -> {
invocation.<ByteBuffer>getArgument(0)
- .put(headerBuffer.duplicate())
- .put(requestBuffer.duplicate());
+ .put(headerBuffer.duplicate())
+ .put(requestBuffer.duplicate());
return headerBuffer.remaining() + requestBuffer.remaining();
});
+ }
+
+ private void testApiVersionsRequest(short version, String expectedSoftwareName,
+ String expectedSoftwareVersion) throws IOException {
+ TransportLayer transportLayer = mockTransportLayer();
+ Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
+ Collections.singletonList(SCRAM_SHA_256.mechanismName()));
+ ChannelMetadataRegistry metadataRegistry = new DefaultChannelMetadataRegistry();
+ SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer, SCRAM_SHA_256.mechanismName(), metadataRegistry);
+
+ RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS, version, clientId, 0);
+ ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(version);
+ mockRequest(header, request, transportLayer);
authenticator.authenticate();
@@ -155,16 +364,24 @@ public class SaslServerAuthenticatorTest {
private SaslServerAuthenticator setupAuthenticator(Map<String, ?> configs, TransportLayer transportLayer,
String mechanism, ChannelMetadataRegistry metadataRegistry) {
+ return setupAuthenticator(configs, transportLayer, mechanism, metadataRegistry, new MockTime(), null);
+ }
+
+ private SaslServerAuthenticator setupAuthenticator(Map<String, ?> configs, TransportLayer transportLayer,
+ String mechanism, ChannelMetadataRegistry metadataRegistry, Time time, Long maxReauth) {
TestJaasConfig jaasConfig = new TestJaasConfig();
- jaasConfig.addEntry("jaasContext", PlainLoginModule.class.getName(), new HashMap<String, Object>());
+ jaasConfig.addEntry("jaasContext", PlainLoginModule.class.getName(), new HashMap<>());
Map<String, Subject> subjects = Collections.singletonMap(mechanism, new Subject());
Map<String, AuthenticateCallbackHandler> callbackHandlers = Collections.singletonMap(
mechanism, new SaslServerCallbackHandler());
ApiVersionsResponse apiVersionsResponse = ApiVersionsResponse.defaultApiVersionsResponse(
- ApiMessageType.ListenerType.ZK_BROKER);
+ ApiMessageType.ListenerType.ZK_BROKER);
+ Map<String, Long> connectionsMaxReauthMsByMechanism = maxReauth != null ?
+ Collections.singletonMap(mechanism, maxReauth) : Collections.emptyMap();
+
return new SaslServerAuthenticator(configs, callbackHandlers, "node", subjects, null,
- new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer, Collections.emptyMap(),
- metadataRegistry, Time.SYSTEM, () -> apiVersionsResponse);
+ new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer, connectionsMaxReauthMsByMechanism,
+ metadataRegistry, time, () -> apiVersionsResponse);
}
}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java
index 9a77c738d2..85f6622f09 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java
@@ -48,6 +48,7 @@ import org.apache.kafka.common.utils.Time;
import org.junit.jupiter.api.Test;
import org.mockito.InOrder;
import org.mockito.Mockito;
+import org.mockito.internal.util.MockUtil;
public class ExpiringCredentialRefreshingLoginTest {
private static final Configuration EMPTY_WILDCARD_CONFIGURATION;
@@ -188,8 +189,7 @@ public class ExpiringCredentialRefreshingLoginTest {
super("contextName", null, null, EMPTY_WILDCARD_CONFIGURATION);
this.testExpiringCredentialRefreshingLogin = Objects.requireNonNull(testExpiringCredentialRefreshingLogin);
// sanity check to make sure it is likely a mock
- if (Objects.requireNonNull(mockLoginContext).getClass().equals(LoginContext.class)
- || mockLoginContext.getClass().equals(getClass()))
+ if (!MockUtil.isMock(mockLoginContext))
throw new IllegalArgumentException();
this.mockLoginContext = mockLoginContext;
}
@@ -233,8 +233,7 @@ public class ExpiringCredentialRefreshingLoginTest {
public void configure(LoginContext mockLoginContext,
TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin) throws LoginException {
// sanity check to make sure it is likely a mock
- if (Objects.requireNonNull(mockLoginContext).getClass().equals(LoginContext.class)
- || mockLoginContext.getClass().equals(TestLoginContext.class))
+ if (!MockUtil.isMock(mockLoginContext))
throw new IllegalArgumentException();
this.testLoginContext = new TestLoginContext(Objects.requireNonNull(testExpiringCredentialRefreshingLogin),
mockLoginContext);