You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by sh...@apache.org on 2022/06/10 13:33:46 UTC

[kafka] branch trunk updated: [KAFKA-13848] Clients remain connected after SASL re-authentication f… (#12179)

This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 3d5b41e05f [KAFKA-13848] Clients remain connected after SASL re-authentication f… (#12179)
3d5b41e05f is described below

commit 3d5b41e05f0011dcf607f7d9cb269e8bb4558c57
Author: András Csáki <ac...@users.noreply.github.com>
AuthorDate: Fri Jun 10 15:33:33 2022 +0200

    [KAFKA-13848] Clients remain connected after SASL re-authentication f… (#12179)
    
    
    
    Clients remain connected and able to produce or consume despite an expired OAUTHBEARER token.
    
    The root cause seems to be that SaslServerAuthenticator#calcCompletionTimesAndReturnSessionLifetimeMs fails to set ReauthInfo#sessionExpirationTimeNanos when the token has already expired (i.e. when the session lifetime goes negative). This in turn causes KafkaChannel#serverAuthenticationSessionExpired to return false, so SocketServer never closes the channel.
    
    The issue is observed with OAUTHBEARER but seems to have a wider impact on SASL re-authentication.
    
    Reviewers: Luke Chen <sh...@gmail.com>, Tom Bentley <tb...@redhat.com>, Sam Barker <sb...@redhat.com>
---
 build.gradle                                       |   2 +-
 .../authenticator/SaslServerAuthenticator.java     |  32 ++-
 .../apache/kafka/common/network/SelectorTest.java  |   8 +-
 .../authenticator/SaslServerAuthenticatorTest.java | 269 +++++++++++++++++++--
 .../ExpiringCredentialRefreshingLoginTest.java     |   7 +-
 5 files changed, 266 insertions(+), 52 deletions(-)

diff --git a/build.gradle b/build.gradle
index 064e397c15..e34010c166 100644
--- a/build.gradle
+++ b/build.gradle
@@ -1245,7 +1245,7 @@ project(':clients') {
 
     testImplementation libs.bcpkix
     testImplementation libs.junitJupiter
-    testImplementation libs.mockitoCore
+    testImplementation libs.mockitoInline
 
     testRuntimeOnly libs.slf4jlog4j
     testRuntimeOnly libs.jacksonDatabind
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
index 6e35ee7a90..019723b6b4 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
@@ -673,30 +673,26 @@ public class SaslServerAuthenticator implements Authenticator {
             Long credentialExpirationMs = (Long) saslServer
                     .getNegotiatedProperty(SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY);
             Long connectionsMaxReauthMs = connectionsMaxReauthMsByMechanism.get(saslMechanism);
-            if (credentialExpirationMs != null || connectionsMaxReauthMs != null) {
+            boolean maxReauthSet = connectionsMaxReauthMs != null && connectionsMaxReauthMs > 0;
+
+            if (credentialExpirationMs != null || maxReauthSet) {
                 if (credentialExpirationMs == null)
                     retvalSessionLifetimeMs = zeroIfNegative(connectionsMaxReauthMs);
-                else if (connectionsMaxReauthMs == null)
+                else if (!maxReauthSet)
                     retvalSessionLifetimeMs = zeroIfNegative(credentialExpirationMs - authenticationEndMs);
                 else
-                    retvalSessionLifetimeMs = zeroIfNegative(
-                            Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs));
-                if (retvalSessionLifetimeMs > 0L)
-                    sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
+                    retvalSessionLifetimeMs = zeroIfNegative(Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs));
+
+                sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
             }
+
             if (credentialExpirationMs != null) {
-                if (sessionExpirationTimeNanos != null)
-                    LOG.debug(
-                            "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); session expiration = {} ({} ms), sending {} ms to client",
-                            connectionsMaxReauthMs, new Date(credentialExpirationMs),
-                            credentialExpirationMs - authenticationEndMs,
-                            new Date(authenticationEndMs + retvalSessionLifetimeMs), retvalSessionLifetimeMs,
-                            retvalSessionLifetimeMs);
-                else
-                    LOG.debug(
-                            "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); no session expiration, sending 0 ms to client",
-                            connectionsMaxReauthMs, new Date(credentialExpirationMs),
-                            credentialExpirationMs - authenticationEndMs);
+                LOG.debug(
+                        "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); session expiration = {} ({} ms), sending {} ms to client",
+                        connectionsMaxReauthMs, new Date(credentialExpirationMs),
+                        credentialExpirationMs - authenticationEndMs,
+                        new Date(authenticationEndMs + retvalSessionLifetimeMs), retvalSessionLifetimeMs,
+                        retvalSessionLifetimeMs);
             } else {
                 if (sessionExpirationTimeNanos != null)
                     LOG.debug(
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
index 59767f9dc2..09f14531de 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
@@ -776,14 +776,13 @@ public class SelectorTest {
         when(kafkaChannel.selectionKey()).thenReturn(selectionKey);
         when(selectionKey.channel()).thenReturn(SocketChannel.open());
         when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_CONNECT);
+        when(selectionKey.attachment()).thenReturn(kafkaChannel);
 
-        selectionKey.attach(kafkaChannel);
         Set<SelectionKey> selectionKeys = Utils.mkSet(selectionKey);
         selector.pollSelectionKeys(selectionKeys, false, System.nanoTime());
 
         assertFalse(selector.connected().contains(kafkaChannel.id()));
         assertTrue(selector.disconnected().containsKey(kafkaChannel.id()));
-        assertNull(selectionKey.attachment());
 
         verify(kafkaChannel, atLeastOnce()).ready();
         verify(kafkaChannel).disconnect();
@@ -971,8 +970,11 @@ public class SelectorTest {
             SelectionKey selectionKey = mock(SelectionKey.class);
             when(channel.selectionKey()).thenReturn(selectionKey);
             when(selectionKey.isValid()).thenReturn(true);
+            when(selectionKey.isReadable()).thenReturn(true);
             when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_READ);
-            selectionKey.attach(channel);
+            when(selectionKey.attachment())
+                    .thenReturn(channel)
+                    .thenReturn(null);
             selectionKeys.add(selectionKey);
 
             NetworkReceive receive = mock(NetworkReceive.class);
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
index af0fedd4f5..50696ecf05 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
@@ -16,10 +16,12 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
-import java.net.InetAddress;
 import org.apache.kafka.common.config.internals.BrokerSecurityConfigs;
 import org.apache.kafka.common.errors.IllegalSaslStateException;
 import org.apache.kafka.common.message.ApiMessageType;
+import org.apache.kafka.common.message.SaslAuthenticateRequestData;
+import org.apache.kafka.common.message.SaslHandshakeRequestData;
+import org.apache.kafka.common.network.ChannelBuilders;
 import org.apache.kafka.common.network.ChannelMetadataRegistry;
 import org.apache.kafka.common.network.ClientInformation;
 import org.apache.kafka.common.network.DefaultChannelMetadataRegistry;
@@ -27,31 +29,55 @@ import org.apache.kafka.common.network.InvalidReceiveException;
 import org.apache.kafka.common.network.ListenerName;
 import org.apache.kafka.common.network.TransportLayer;
 import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.requests.ApiVersionsRequest;
 import org.apache.kafka.common.requests.ApiVersionsResponse;
+import org.apache.kafka.common.requests.RequestHeader;
 import org.apache.kafka.common.requests.RequestTestUtils;
+import org.apache.kafka.common.requests.ResponseHeader;
+import org.apache.kafka.common.requests.SaslAuthenticateRequest;
+import org.apache.kafka.common.requests.SaslAuthenticateResponse;
+import org.apache.kafka.common.requests.SaslHandshakeRequest;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.KafkaPrincipal;
+import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
 import org.apache.kafka.common.security.auth.SecurityProtocol;
-import org.apache.kafka.common.requests.RequestHeader;
+import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
+import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
+import org.apache.kafka.common.security.ssl.SslPrincipalMapper;
 import org.apache.kafka.common.utils.AppInfoParser;
+import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
+import org.junit.jupiter.api.Test;
+import org.mockito.Answers;
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
 
 import javax.security.auth.Subject;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
 import java.io.IOException;
+import java.net.InetAddress;
+import java.nio.Buffer;
 import java.nio.ByteBuffer;
+import java.time.Duration;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
-
-import org.junit.jupiter.api.Test;
-import org.mockito.Answers;
+import java.util.stream.Collectors;
 
 import static org.apache.kafka.common.security.scram.internals.ScramMechanism.SCRAM_SHA_256;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.fail;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyMap;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -59,13 +85,15 @@ import static org.mockito.Mockito.when;
 
 public class SaslServerAuthenticatorTest {
 
+    private final String clientId = "clientId";
+    
     @Test
     public void testOversizeRequest() throws IOException {
         TransportLayer transportLayer = mock(TransportLayer.class);
         Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
                 Collections.singletonList(SCRAM_SHA_256.mechanismName()));
         SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer,
-            SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry());
+                SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry());
 
         when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> {
             invocation.<ByteBuffer>getArgument(0).putInt(SaslServerAuthenticator.MAX_RECEIVE_SIZE + 1);
@@ -81,9 +109,9 @@ public class SaslServerAuthenticatorTest {
         Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
                 Collections.singletonList(SCRAM_SHA_256.mechanismName()));
         SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer,
-            SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry());
+                SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry());
 
-        RequestHeader header = new RequestHeader(ApiKeys.METADATA, (short) 0, "clientId", 13243);
+        RequestHeader header = new RequestHeader(ApiKeys.METADATA, (short) 0, clientId, 13243);
         ByteBuffer headerBuffer = RequestTestUtils.serializeRequestHeader(header);
 
         when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> {
@@ -108,42 +136,223 @@ public class SaslServerAuthenticatorTest {
     @Test
     public void testOldestApiVersionsRequest() throws IOException {
         testApiVersionsRequest(ApiKeys.API_VERSIONS.oldestVersion(),
-            ClientInformation.UNKNOWN_NAME_OR_VERSION, ClientInformation.UNKNOWN_NAME_OR_VERSION);
+                ClientInformation.UNKNOWN_NAME_OR_VERSION, ClientInformation.UNKNOWN_NAME_OR_VERSION);
     }
 
     @Test
     public void testLatestApiVersionsRequest() throws IOException {
         testApiVersionsRequest(ApiKeys.API_VERSIONS.latestVersion(),
-            "apache-kafka-java", AppInfoParser.getVersion());
+                "apache-kafka-java", AppInfoParser.getVersion());
     }
 
-    private void testApiVersionsRequest(short version, String expectedSoftwareName,
-                                       String expectedSoftwareVersion) throws IOException {
-        TransportLayer transportLayer = mock(TransportLayer.class, Answers.RETURNS_DEEP_STUBS);
+    @Test
+    public void testSessionExpiresAtTokenExpiryDespiteNoReauthIsSet() throws IOException {
+        String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
+        Duration tokenExpirationDuration = Duration.ofSeconds(1);
+        SaslServer saslServer = mock(SaslServer.class);
+
+        MockTime time = new MockTime();
+        try (
+                MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, tokenExpirationDuration);
+                MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
+                TransportLayer transportLayer = mockTransportLayer()
+        ) {
+
+            SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, 0L);
+
+            mockRequest(saslHandshakeRequest(mechanism), transportLayer);
+            authenticator.authenticate();
+
+            when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
+            mockRequest(saslAuthenticateRequest(), transportLayer);
+            authenticator.authenticate();
+
+            long atTokenExpiryNanos = time.nanoseconds() + tokenExpirationDuration.toNanos();
+            assertEquals(atTokenExpiryNanos, authenticator.serverSessionExpirationTimeNanos());
+
+            ByteBuffer secondResponseSent = getResponses(transportLayer).get(1);
+            consumeSizeAndHeader(secondResponseSent);
+            SaslAuthenticateResponse response = SaslAuthenticateResponse.parse(secondResponseSent, (short) 2);
+            assertEquals(tokenExpirationDuration.toMillis(), response.sessionLifetimeMs());
+        }
+    }
+
+    @Test
+    public void testSessionExpiresAtMaxReauthTime() throws IOException {
+        String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
+        SaslServer saslServer = mock(SaslServer.class);
+        MockTime time = new MockTime(0, 1, 1000);
+        long maxReauthMs = 100L;
+        Duration tokenExpiryGreaterThanMaxReauth = Duration.ofMillis(maxReauthMs).multipliedBy(10);
+
+        try (
+                MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, tokenExpiryGreaterThanMaxReauth);
+                MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
+                TransportLayer transportLayer = mockTransportLayer()
+        ) {
+
+            SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, maxReauthMs);
+
+            mockRequest(saslHandshakeRequest(mechanism), transportLayer);
+            authenticator.authenticate();
+
+            when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
+            mockRequest(saslAuthenticateRequest(), transportLayer);
+            authenticator.authenticate();
+
+            long atMaxReauthNanos = time.nanoseconds() + Duration.ofMillis(maxReauthMs).toNanos();
+            assertEquals(atMaxReauthNanos, authenticator.serverSessionExpirationTimeNanos());
+
+            ByteBuffer secondResponseSent = getResponses(transportLayer).get(1);
+            consumeSizeAndHeader(secondResponseSent);
+            SaslAuthenticateResponse response = SaslAuthenticateResponse.parse(secondResponseSent, (short) 2);
+            assertEquals(maxReauthMs, response.sessionLifetimeMs());
+        }
+    }
+
+    @Test
+    public void testSessionExpiresAtTokenExpiry() throws IOException {
+        String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
+        SaslServer saslServer = mock(SaslServer.class);
+        MockTime time = new MockTime(0, 1, 1000);
+        Duration tokenExpiryShorterThanMaxReauth = Duration.ofSeconds(2);
+        long maxReauthMs = tokenExpiryShorterThanMaxReauth.multipliedBy(2).toMillis();
+
+        try (
+                MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, tokenExpiryShorterThanMaxReauth);
+                MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
+                TransportLayer transportLayer = mockTransportLayer()
+        ) {
+
+            SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, maxReauthMs);
+
+            mockRequest(saslHandshakeRequest(mechanism), transportLayer);
+            authenticator.authenticate();
+
+            when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
+            mockRequest(saslAuthenticateRequest(), transportLayer);
+            authenticator.authenticate();
+
+            long atTokenExpiryNanos = time.nanoseconds() + tokenExpiryShorterThanMaxReauth.toNanos();
+            assertEquals(atTokenExpiryNanos, authenticator.serverSessionExpirationTimeNanos());
+
+            ByteBuffer secondResponseSent = getResponses(transportLayer).get(1);
+            consumeSizeAndHeader(secondResponseSent);
+            SaslAuthenticateResponse response = SaslAuthenticateResponse.parse(secondResponseSent, (short) 2);
+            assertEquals(tokenExpiryShorterThanMaxReauth.toMillis(), response.sessionLifetimeMs());
+        }
+    }
+
+    private SaslServerAuthenticator getSaslServerAuthenticatorForOAuth(String mechanism, TransportLayer transportLayer, Time time, Long maxReauth) {
         Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
-            Collections.singletonList(SCRAM_SHA_256.mechanismName()));
+                Collections.singletonList(mechanism));
         ChannelMetadataRegistry metadataRegistry = new DefaultChannelMetadataRegistry();
-        SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer,
-            SCRAM_SHA_256.mechanismName(), metadataRegistry);
 
-        RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS, version, "clientId", 0);
+        return setupAuthenticator(configs, transportLayer, mechanism, metadataRegistry, time, maxReauth);
+    }
+
+    private MockedStatic<?> mockSaslServer(SaslServer saslServer, String mechanism, Time time, Duration tokenExpirationDuration) throws SaslException {
+        when(saslServer.getMechanismName()).thenReturn(mechanism);
+        when(saslServer.evaluateResponse(any())).thenReturn(new byte[]{});
+        long millisToExpiration = tokenExpirationDuration.toMillis();
+        when(saslServer.getNegotiatedProperty(eq(SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY)))
+                .thenReturn(time.milliseconds() + millisToExpiration);
+        return Mockito.mockStatic(Sasl.class, (Answer<SaslServer>) invocation -> saslServer);
+    }
+
+    private MockedStatic<?> mockKafkaPrincipal(String principalType, String name) {
+        KafkaPrincipalBuilder kafkaPrincipalBuilder = mock(KafkaPrincipalBuilder.class);
+        when(kafkaPrincipalBuilder.build(any())).thenReturn(new KafkaPrincipal(principalType, name));
+        MockedStatic<ChannelBuilders> channelBuilders = Mockito.mockStatic(ChannelBuilders.class, Answers.RETURNS_MOCKS);
+        channelBuilders.when(() ->
+                ChannelBuilders.createPrincipalBuilder(anyMap(), any(KerberosShortNamer.class), any(SslPrincipalMapper.class))
+        ).thenReturn(kafkaPrincipalBuilder);
+        return channelBuilders;
+    }
+
+    private void consumeSizeAndHeader(ByteBuffer responseBuffer) {
+        responseBuffer.getInt();
+        ResponseHeader.parse(responseBuffer, (short) 1);
+    }
+
+    private List<ByteBuffer> getResponses(TransportLayer transportLayer) throws IOException {
+        ArgumentCaptor<ByteBuffer[]> buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class);
+        verify(transportLayer, times(2)).write(buffersCaptor.capture());
+        return buffersCaptor.getAllValues().stream()
+                .map(this::concatBuffers)
+                .collect(Collectors.toList());
+    }
+
+    private ByteBuffer concatBuffers(ByteBuffer[] buffers) {
+        int combinedCapacity = 0;
+        for (ByteBuffer buffer : buffers) {
+            combinedCapacity += buffer.capacity();
+        }
+        if (combinedCapacity > 0) {
+            ByteBuffer concat = ByteBuffer.allocate(combinedCapacity);
+            for (ByteBuffer buffer : buffers) {
+                concat.put(buffer);
+            }
+            return safeFlip(concat);
+        } else {
+            return ByteBuffer.allocate(0);
+        }
+    }
+
+    private ByteBuffer safeFlip(ByteBuffer buffer) {
+        return (ByteBuffer) ((Buffer) buffer).flip();
+    }
+
+    private SaslAuthenticateRequest saslAuthenticateRequest() {
+        SaslAuthenticateRequestData authenticateRequestData = new SaslAuthenticateRequestData();
+        return new SaslAuthenticateRequest.Builder(authenticateRequestData).build(ApiKeys.SASL_AUTHENTICATE.latestVersion());
+    }
+
+    private SaslHandshakeRequest saslHandshakeRequest(String mechanism) {
+        SaslHandshakeRequestData handshakeRequestData = new SaslHandshakeRequestData();
+        handshakeRequestData.setMechanism(mechanism);
+        return new SaslHandshakeRequest.Builder(handshakeRequestData).build(ApiKeys.SASL_HANDSHAKE.latestVersion());
+    }
+
+    private TransportLayer mockTransportLayer() throws IOException {
+        TransportLayer transportLayer = mock(TransportLayer.class, Answers.RETURNS_DEEP_STUBS);
+        when(transportLayer.socketChannel().socket().getInetAddress()).thenReturn(InetAddress.getLoopbackAddress());
+        when(transportLayer.write(any(ByteBuffer[].class))).thenReturn(Long.MAX_VALUE);
+        return transportLayer;
+    }
+
+    private void mockRequest(AbstractRequest request, TransportLayer transportLayer) throws IOException {
+        mockRequest(new RequestHeader(request.apiKey(), request.apiKey().latestVersion(), clientId, 0), request, transportLayer);
+    }
+
+    private void mockRequest(RequestHeader header, AbstractRequest request, TransportLayer transportLayer) throws IOException {
         ByteBuffer headerBuffer = RequestTestUtils.serializeRequestHeader(header);
 
-        ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(version);
         ByteBuffer requestBuffer = request.serialize();
         requestBuffer.rewind();
 
-        when(transportLayer.socketChannel().socket().getInetAddress()).thenReturn(InetAddress.getLoopbackAddress());
-
         when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> {
             invocation.<ByteBuffer>getArgument(0).putInt(headerBuffer.remaining() + requestBuffer.remaining());
             return 4;
         }).then(invocation -> {
             invocation.<ByteBuffer>getArgument(0)
-                .put(headerBuffer.duplicate())
-                .put(requestBuffer.duplicate());
+                    .put(headerBuffer.duplicate())
+                    .put(requestBuffer.duplicate());
             return headerBuffer.remaining() + requestBuffer.remaining();
         });
+    }
+
+    private void testApiVersionsRequest(short version, String expectedSoftwareName,
+                                        String expectedSoftwareVersion) throws IOException {
+        TransportLayer transportLayer = mockTransportLayer();
+        Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
+                Collections.singletonList(SCRAM_SHA_256.mechanismName()));
+        ChannelMetadataRegistry metadataRegistry = new DefaultChannelMetadataRegistry();
+        SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer, SCRAM_SHA_256.mechanismName(), metadataRegistry);
+
+        RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS, version, clientId, 0);
+        ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(version);
+        mockRequest(header, request, transportLayer);
 
         authenticator.authenticate();
 
@@ -155,16 +364,24 @@ public class SaslServerAuthenticatorTest {
 
     private SaslServerAuthenticator setupAuthenticator(Map<String, ?> configs, TransportLayer transportLayer,
                                                        String mechanism, ChannelMetadataRegistry metadataRegistry) {
+        return setupAuthenticator(configs, transportLayer, mechanism, metadataRegistry, new MockTime(), null);
+    }
+
+    private SaslServerAuthenticator setupAuthenticator(Map<String, ?> configs, TransportLayer transportLayer,
+                                                       String mechanism, ChannelMetadataRegistry metadataRegistry, Time time, Long maxReauth) {
         TestJaasConfig jaasConfig = new TestJaasConfig();
-        jaasConfig.addEntry("jaasContext", PlainLoginModule.class.getName(), new HashMap<String, Object>());
+        jaasConfig.addEntry("jaasContext", PlainLoginModule.class.getName(), new HashMap<>());
         Map<String, Subject> subjects = Collections.singletonMap(mechanism, new Subject());
         Map<String, AuthenticateCallbackHandler> callbackHandlers = Collections.singletonMap(
                 mechanism, new SaslServerCallbackHandler());
         ApiVersionsResponse apiVersionsResponse = ApiVersionsResponse.defaultApiVersionsResponse(
-            ApiMessageType.ListenerType.ZK_BROKER);
+                ApiMessageType.ListenerType.ZK_BROKER);
+        Map<String, Long> connectionsMaxReauthMsByMechanism = maxReauth != null ?
+                Collections.singletonMap(mechanism, maxReauth) : Collections.emptyMap();
+
         return new SaslServerAuthenticator(configs, callbackHandlers, "node", subjects, null,
-                new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer, Collections.emptyMap(),
-                metadataRegistry, Time.SYSTEM, () -> apiVersionsResponse);
+                new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer, connectionsMaxReauthMsByMechanism,
+                metadataRegistry, time, () -> apiVersionsResponse);
     }
 
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java
index 9a77c738d2..85f6622f09 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java
@@ -48,6 +48,7 @@ import org.apache.kafka.common.utils.Time;
 import org.junit.jupiter.api.Test;
 import org.mockito.InOrder;
 import org.mockito.Mockito;
+import org.mockito.internal.util.MockUtil;
 
 public class ExpiringCredentialRefreshingLoginTest {
     private static final Configuration EMPTY_WILDCARD_CONFIGURATION;
@@ -188,8 +189,7 @@ public class ExpiringCredentialRefreshingLoginTest {
             super("contextName", null, null, EMPTY_WILDCARD_CONFIGURATION);
             this.testExpiringCredentialRefreshingLogin = Objects.requireNonNull(testExpiringCredentialRefreshingLogin);
             // sanity check to make sure it is likely a mock
-            if (Objects.requireNonNull(mockLoginContext).getClass().equals(LoginContext.class)
-                    || mockLoginContext.getClass().equals(getClass()))
+            if (!MockUtil.isMock(mockLoginContext))
                 throw new IllegalArgumentException();
             this.mockLoginContext = mockLoginContext;
         }
@@ -233,8 +233,7 @@ public class ExpiringCredentialRefreshingLoginTest {
         public void configure(LoginContext mockLoginContext,
                 TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin) throws LoginException {
             // sanity check to make sure it is likely a mock
-            if (Objects.requireNonNull(mockLoginContext).getClass().equals(LoginContext.class)
-                    || mockLoginContext.getClass().equals(TestLoginContext.class))
+            if (!MockUtil.isMock(mockLoginContext))
                 throw new IllegalArgumentException();
             this.testLoginContext = new TestLoginContext(Objects.requireNonNull(testExpiringCredentialRefreshingLogin),
                     mockLoginContext);