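You are viewing a plain-text version of this message. The canonical version is available in the commits@pulsar.apache.org mailing-list archive (the original hyperlink was lost in this plain-text copy).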
Posted to commits@pulsar.apache.org by mm...@apache.org on 2023/02/22 22:11:24 UTC

[pulsar] 02/03: [fix][broker] Make authentication refresh threadsafe (#19506)

This is an automated email from the ASF dual-hosted git repository.

mmarshall pushed a commit to branch branch-2.10
in repository https://gitbox.apache.org/repos/asf/pulsar.git

commit 26e10536773498cfd0a4514256456795658dd6d8
Author: Michael Marshall <mm...@apache.org>
AuthorDate: Tue Feb 14 03:09:55 2023 -0600

    [fix][broker] Make authentication refresh threadsafe  (#19506)
    
    Co-authored-by: Lari Hotari <lh...@users.noreply.github.com>
    (cherry picked from commit 153e4d4cc3b56aaee224b0a68e0186c08125c975)
    (cherry picked from commit 161ec5aa20c4e0d9f82473e43e5ccdc7a113f236)
---
 .../broker/service/PulsarChannelInitializer.java   |  29 ------
 .../apache/pulsar/broker/service/ServerCnx.java    | 109 +++++++++++++--------
 .../pulsar/broker/service/ServerCnxTest.java       |  14 ++-
 3 files changed, 76 insertions(+), 76 deletions(-)

diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/PulsarChannelInitializer.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/PulsarChannelInitializer.java
index e75c518a50f..e1057de54cc 100644
--- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/PulsarChannelInitializer.java
+++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/PulsarChannelInitializer.java
@@ -18,9 +18,6 @@
  */
 package org.apache.pulsar.broker.service;
 
-import static org.apache.bookkeeper.util.SafeRunnable.safeRun;
-import com.github.benmanes.caffeine.cache.Cache;
-import com.github.benmanes.caffeine.cache.Caffeine;
 import com.google.common.annotations.VisibleForTesting;
 import io.netty.channel.ChannelInitializer;
 import io.netty.channel.socket.SocketChannel;
@@ -29,8 +26,6 @@ import io.netty.handler.flow.FlowControlHandler;
 import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslHandler;
 import io.netty.handler.ssl.SslProvider;
-import java.net.SocketAddress;
-import java.util.concurrent.TimeUnit;
 import lombok.Builder;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
@@ -56,14 +51,6 @@ public class PulsarChannelInitializer extends ChannelInitializer<SocketChannel>
     private final ServiceConfiguration brokerConf;
     private NettySSLContextAutoRefreshBuilder nettySSLContextAutoRefreshBuilder;
 
-    // This cache is used to maintain a list of active connections to iterate over them
-    // We keep weak references to have the cache to be auto cleaned up when the connections
-    // objects are GCed.
-    private final Cache<SocketAddress, ServerCnx> connections = Caffeine.newBuilder()
-            .weakKeys()
-            .weakValues()
-            .build();
-
     /**
      * @param pulsar
      *              An instance of {@link PulsarService}
@@ -112,10 +99,6 @@ public class PulsarChannelInitializer extends ChannelInitializer<SocketChannel>
             this.sslCtxRefresher = null;
         }
         this.brokerConf = pulsar.getConfiguration();
-
-        pulsar.getExecutor().scheduleAtFixedRate(safeRun(this::refreshAuthenticationCredentials),
-                pulsar.getConfig().getAuthenticationRefreshCheckSeconds(),
-                pulsar.getConfig().getAuthenticationRefreshCheckSeconds(), TimeUnit.SECONDS);
     }
 
     @Override
@@ -145,18 +128,6 @@ public class PulsarChannelInitializer extends ChannelInitializer<SocketChannel>
         ch.pipeline().addLast("flowController", new FlowControlHandler());
         ServerCnx cnx = newServerCnx(pulsar, listenerName);
         ch.pipeline().addLast("handler", cnx);
-
-        connections.put(ch.remoteAddress(), cnx);
-    }
-
-    private void refreshAuthenticationCredentials() {
-        connections.asMap().values().forEach(cnx -> {
-            try {
-                cnx.refreshAuthenticationCredentials();
-            } catch (Throwable t) {
-                log.warn("[{}] Failed to refresh auth credentials", cnx.clientAddress());
-            }
-        });
     }
 
     @VisibleForTesting
diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java
index 851058dc811..954ab1d182f 100644
--- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java
+++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java
@@ -37,6 +37,7 @@ import io.netty.handler.codec.haproxy.HAProxyMessage;
 import io.netty.handler.ssl.SslHandler;
 import io.netty.util.concurrent.FastThreadLocal;
 import io.netty.util.concurrent.Promise;
+import io.netty.util.concurrent.ScheduledFuture;
 import io.prometheus.client.Gauge;
 import java.io.IOException;
 import java.net.InetSocketAddress;
@@ -177,6 +178,7 @@ public class ServerCnx extends PulsarHandler implements TransportCnx {
     private AuthenticationState originalAuthState;
     private volatile AuthenticationDataSource originalAuthData;
     private boolean pendingAuthChallengeResponse = false;
+    private ScheduledFuture<?> authRefreshTask;
 
     // Max number of pending requests per connections. If multiple producers are sharing the same connection the flow
     // control done by a single producer might not be enough to prevent write spikes on the broker.
@@ -306,6 +308,9 @@ public class ServerCnx extends PulsarHandler implements TransportCnx {
         }
 
         cnxsPerThread.get().remove(this);
+        if (authRefreshTask != null) {
+            authRefreshTask.cancel(false);
+        }
 
         // Connection is gone, close the producers immediately
         producers.forEach((__, producerFuture) -> {
@@ -656,15 +661,19 @@ public class ServerCnx extends PulsarHandler implements TransportCnx {
 
             if (state != State.Connected) {
                 // First time authentication is done
-                if (service.isAuthenticationEnabled() && service.isAuthorizationEnabled()) {
-                    if (!service.getAuthorizationService()
-                            .isValidOriginalPrincipal(this.authRole, originalPrincipal, remoteAddress)) {
-                        state = State.Failed;
-                        service.getPulsarStats().recordConnectionCreateFail();
-                        final ByteBuf msg = Commands.newError(-1, ServerError.AuthorizationError, "Invalid roles.");
-                        ctx.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE);
-                        return;
+                if (service.isAuthenticationEnabled()) {
+                    if (service.isAuthorizationEnabled()) {
+                        if (!service.getAuthorizationService()
+                                .isValidOriginalPrincipal(this.authRole, originalPrincipal, remoteAddress)) {
+                            state = State.Failed;
+                            service.getPulsarStats().recordConnectionCreateFail();
+                            final ByteBuf msg = Commands.newError(-1, ServerError.AuthorizationError,
+                                    "Invalid roles.");
+                            ctx.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE);
+                            return;
+                        }
                     }
+                    maybeScheduleAuthenticationCredentialsRefresh();
                 }
                 completeConnect(clientProtocolVersion, clientVersion);
             } else {
@@ -691,61 +700,75 @@ public class ServerCnx extends PulsarHandler implements TransportCnx {
         }
     }
 
-    public void refreshAuthenticationCredentials() {
-        AuthenticationState authState = this.originalAuthState != null ? originalAuthState : this.authState;
-
+    /**
+     * Method to initialize the {@link #authRefreshTask} task.
+     */
+    private void maybeScheduleAuthenticationCredentialsRefresh() {
+        assert ctx.executor().inEventLoop();
+        assert authRefreshTask == null;
         if (authState == null) {
             // Authentication is disabled or there's no local state to refresh
             return;
-        } else if (getState() != State.Connected || !isActive) {
-            // Connection is either still being established or already closed.
+        }
+        authRefreshTask = ctx.executor().scheduleAtFixedRate(this::refreshAuthenticationCredentials,
+                service.getPulsar().getConfig().getAuthenticationRefreshCheckSeconds(),
+                service.getPulsar().getConfig().getAuthenticationRefreshCheckSeconds(),
+                TimeUnit.SECONDS);
+    }
+
+    private void refreshAuthenticationCredentials() {
+        assert ctx.executor().inEventLoop();
+        AuthenticationState authState = this.originalAuthState != null ? originalAuthState : this.authState;
+        if (getState() == State.Failed) {
+            // Happens when an exception is thrown that causes this connection to close.
             return;
         } else if (!authState.isExpired()) {
             // Credentials are still valid. Nothing to do at this point
             return;
         } else if (originalPrincipal != null && originalAuthState == null) {
+            // This case is only checked when the authState is expired because we've reached a point where
+            // authentication needs to be refreshed, but the protocol does not support it unless the proxy forwards
+            // the originalAuthData.
             log.info(
                     "[{}] Cannot revalidate user credential when using proxy and"
                             + " not forwarding the credentials. Closing connection",
                     remoteAddress);
+            ctx.close();
             return;
         }
 
-        ctx.executor().execute(SafeRun.safeRun(() -> {
-            log.info("[{}] Refreshing authentication credentials for originalPrincipal {} and authRole {}",
-                    remoteAddress, originalPrincipal, this.authRole);
-
-            if (!supportsAuthenticationRefresh()) {
-                log.warn("[{}] Closing connection because client doesn't support auth credentials refresh",
-                        remoteAddress);
-                ctx.close();
-                return;
-            }
+        if (!supportsAuthenticationRefresh()) {
+            log.warn("[{}] Closing connection because client doesn't support auth credentials refresh",
+                    remoteAddress);
+            ctx.close();
+            return;
+        }
 
-            if (pendingAuthChallengeResponse) {
-                log.warn("[{}] Closing connection after timeout on refreshing auth credentials",
-                        remoteAddress);
-                ctx.close();
-                return;
-            }
+        if (pendingAuthChallengeResponse) {
+            log.warn("[{}] Closing connection after timeout on refreshing auth credentials",
+                    remoteAddress);
+            ctx.close();
+            return;
+        }
 
-            try {
-                AuthData brokerData = authState.refreshAuthentication();
+        log.info("[{}] Refreshing authentication credentials for originalPrincipal {} and authRole {}",
+                remoteAddress, originalPrincipal, this.authRole);
+        try {
+            AuthData brokerData = authState.refreshAuthentication();
 
-                ctx.writeAndFlush(Commands.newAuthChallenge(authMethod, brokerData,
-                        getRemoteEndpointProtocolVersion()));
-                if (log.isDebugEnabled()) {
-                    log.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.",
-                        remoteAddress, authMethod);
-                }
+            ctx.writeAndFlush(Commands.newAuthChallenge(authMethod, brokerData,
+                    getRemoteEndpointProtocolVersion()));
+            if (log.isDebugEnabled()) {
+                log.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.",
+                    remoteAddress, authMethod);
+            }
 
-                pendingAuthChallengeResponse = true;
+            pendingAuthChallengeResponse = true;
 
-            } catch (AuthenticationException e) {
-                log.warn("[{}] Failed to refresh authentication: {}", remoteAddress, e);
-                ctx.close();
-            }
-        }));
+        } catch (AuthenticationException e) {
+            log.warn("[{}] Failed to refresh authentication: {}", remoteAddress, e);
+            ctx.close();
+        }
     }
 
     private static final byte[] emptyArray = new byte[0];
diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ServerCnxTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ServerCnxTest.java
index a994a3adbad..c39e1f5b7e4 100644
--- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ServerCnxTest.java
+++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ServerCnxTest.java
@@ -496,10 +496,13 @@ public class ServerCnxTest {
         when(brokerService.getAuthenticationService()).thenReturn(authenticationService);
         when(authenticationService.getAuthenticationProvider(authMethodName)).thenReturn(authenticationProvider);
         svcConfig.setAuthenticationEnabled(true);
+        svcConfig.setAuthenticationRefreshCheckSeconds(30);
 
         resetChannel();
         assertTrue(channel.isActive());
         assertEquals(serverCnx.getState(), State.Start);
+        // Don't want the keep alive task affecting which messages are handled
+        serverCnx.cancelKeepAliveTask();
 
         ByteBuf clientCommand = Commands.newConnect(authMethodName, "pass.client", "");
         channel.writeInbound(clientCommand);
@@ -512,7 +515,7 @@ public class ServerCnxTest {
 
         // Trigger the ServerCnx to check if authentication is expired (it is because of our special implementation)
         // and then force channel to run the task
-        serverCnx.refreshAuthenticationCredentials();
+        channel.advanceTimeBy(30, TimeUnit.SECONDS);
         channel.runPendingTasks();
         Object responseAuthChallenge1 = getResponse();
         assertTrue(responseAuthChallenge1 instanceof CommandAuthChallenge);
@@ -522,7 +525,7 @@ public class ServerCnxTest {
         channel.writeInbound(authResponse1);
 
         // Trigger the ServerCnx to check if authentication is expired again
-        serverCnx.refreshAuthenticationCredentials();
+        channel.advanceTimeBy(30, TimeUnit.SECONDS);
         assertTrue(channel.hasPendingTasks(), "This test assumes there are pending tasks to run.");
         channel.runPendingTasks();
         Object responseAuthChallenge2 = getResponse();
@@ -548,10 +551,13 @@ public class ServerCnxTest {
         svcConfig.setAuthenticationEnabled(true);
         svcConfig.setAuthenticateOriginalAuthData(true);
         svcConfig.setProxyRoles(Collections.singleton("pass.proxy"));
+        svcConfig.setAuthenticationRefreshCheckSeconds(30);
 
         resetChannel();
         assertTrue(channel.isActive());
         assertEquals(serverCnx.getState(), State.Start);
+        // Don't want the keep alive task affecting which messages are handled
+        serverCnx.cancelKeepAliveTask();
 
         ByteBuf clientCommand = Commands.newConnect(authMethodName, "pass.proxy", 1, null,
                 null, "pass.client", "pass.client", authMethodName);
@@ -568,7 +574,7 @@ public class ServerCnxTest {
 
         // Trigger the ServerCnx to check if authentication is expired (it is because of our special implementation)
         // and then force channel to run the task
-        serverCnx.refreshAuthenticationCredentials();
+        channel.advanceTimeBy(30, TimeUnit.SECONDS);
         assertTrue(channel.hasPendingTasks(), "This test assumes there are pending tasks to run.");
         channel.runPendingTasks();
         Object responseAuthChallenge1 = getResponse();
@@ -579,7 +585,7 @@ public class ServerCnxTest {
         channel.writeInbound(authResponse1);
 
         // Trigger the ServerCnx to check if authentication is expired again
-        serverCnx.refreshAuthenticationCredentials();
+        channel.advanceTimeBy(30, TimeUnit.SECONDS);
         channel.runPendingTasks();
         Object responseAuthChallenge2 = getResponse();
         assertTrue(responseAuthChallenge2 instanceof CommandAuthChallenge);