You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@iceberg.apache.org by dw...@apache.org on 2023/04/04 17:37:52 UTC

[iceberg] 02/06: AWS: Prevent token refresh scheduling on every sign request (#7270)

This is an automated email from the ASF dual-hosted git repository.

dweeks pushed a commit to branch 1.2.x
in repository https://gitbox.apache.org/repos/asf/iceberg.git

commit ffe875687024572e0c5e37caf98568c78c1f89da
Author: Eduard Tudenhöfner <et...@gmail.com>
AuthorDate: Mon Apr 3 19:42:10 2023 +0200

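    AWS: Prevent token refresh scheduling on every sign request (#7270)

    Previously, authSession() built a new AuthSession on every call, so every
    sign request scheduled another token refresh. AuthSessions are now cached
    per token (or per credential) in a Caffeine cache. Entries expire after
    AUTH_SESSION_TIMEOUT_MS without access, and an evicted session stops
    refreshing. TestS3RestSigner now verifies that exactly one refresh task
    is queued after the signing tests run.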
---
 .../aws/s3/signer/S3V4RestSignerClient.java        | 60 +++++++++++++++++-----
 .../iceberg/aws/s3/signer/S3SignerServlet.java     |  2 +
 .../iceberg/aws/s3/signer/TestS3RestSigner.java    | 52 ++++++++++++++-----
 3 files changed, 87 insertions(+), 27 deletions(-)

diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java
index 032cb03c4f..a0920774af 100644
--- a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java
+++ b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java
@@ -20,7 +20,9 @@ package org.apache.iceberg.aws.s3.signer;
 
 import com.github.benmanes.caffeine.cache.Cache;
 import com.github.benmanes.caffeine.cache.Caffeine;
+import com.github.benmanes.caffeine.cache.RemovalListener;
 import java.net.URI;
+import java.time.Duration;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -30,6 +32,7 @@ import java.util.function.Consumer;
 import java.util.function.Supplier;
 import javax.annotation.Nullable;
 import org.apache.iceberg.CatalogProperties;
+import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
@@ -115,7 +118,8 @@ public abstract class S3V4RestSignerClient
         OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT);
   }
 
-  private ScheduledExecutorService tokenRefreshExecutor() {
+  @VisibleForTesting
+  ScheduledExecutorService tokenRefreshExecutor() {
     if (!keepTokenRefreshed()) {
       return null;
     }
@@ -131,6 +135,26 @@ public abstract class S3V4RestSignerClient
     return tokenRefreshExecutor;
   }
 
+  @Value.Lazy
+  Cache<String, AuthSession> authSessionCache() {
+    long expirationIntervalMs =
+        PropertyUtil.propertyAsLong(
+            properties(),
+            CatalogProperties.AUTH_SESSION_TIMEOUT_MS,
+            CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT);
+
+    return Caffeine.newBuilder()
+        .expireAfterAccess(Duration.ofMillis(expirationIntervalMs))
+        .removalListener(
+            (RemovalListener<String, AuthSession>)
+                (id, auth, cause) -> {
+                  if (null != auth) {
+                    auth.stopRefreshing();
+                  }
+                })
+        .build();
+  }
+
   private RESTClient httpClient() {
     if (null == httpClient) {
       synchronized (S3V4RestSignerClient.class) {
@@ -150,21 +174,31 @@ public abstract class S3V4RestSignerClient
   private AuthSession authSession() {
     String token = token().get();
     if (null != token) {
-      return AuthSession.fromAccessToken(
-          httpClient(),
-          tokenRefreshExecutor(),
-          token,
-          expiresAtMillis(properties()),
-          new AuthSession(ImmutableMap.of(), token, null, credential(), SCOPE));
+      return authSessionCache()
+          .get(
+              token,
+              id ->
+                  AuthSession.fromAccessToken(
+                      httpClient(),
+                      tokenRefreshExecutor(),
+                      token,
+                      expiresAtMillis(properties()),
+                      new AuthSession(ImmutableMap.of(), token, null, credential(), SCOPE)));
     }
 
     if (credentialProvided()) {
-      AuthSession session = new AuthSession(ImmutableMap.of(), null, null, credential(), SCOPE);
-      long startTimeMillis = System.currentTimeMillis();
-      OAuthTokenResponse authResponse =
-          OAuth2Util.fetchToken(httpClient(), session.headers(), credential(), SCOPE);
-      return AuthSession.fromTokenResponse(
-          httpClient(), tokenRefreshExecutor(), authResponse, startTimeMillis, session);
+      return authSessionCache()
+          .get(
+              credential(),
+              id -> {
+                AuthSession session =
+                    new AuthSession(ImmutableMap.of(), null, null, credential(), SCOPE);
+                long startTimeMillis = System.currentTimeMillis();
+                OAuthTokenResponse authResponse =
+                    OAuth2Util.fetchToken(httpClient(), session.headers(), credential(), SCOPE);
+                return AuthSession.fromTokenResponse(
+                    httpClient(), tokenRefreshExecutor(), authResponse, startTimeMillis, session);
+              });
     }
 
     return AuthSession.empty();
diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/S3SignerServlet.java b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/S3SignerServlet.java
index ad2177a01a..6240efa2ad 100644
--- a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/S3SignerServlet.java
+++ b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/S3SignerServlet.java
@@ -112,6 +112,7 @@ public class S3SignerServlet extends HttpServlet {
                 .withToken("client-credentials-token:sub=" + requestMap.get("client_id"))
                 .withIssuedTokenType("urn:ietf:params:oauth:token-type:access_token")
                 .withTokenType("Bearer")
+                .setExpirationInSeconds(100)
                 .build());
 
       case "urn:ietf:params:oauth:grant-type:token-exchange":
@@ -126,6 +127,7 @@ public class S3SignerServlet extends HttpServlet {
                 .withToken(token)
                 .withIssuedTokenType("urn:ietf:params:oauth:token-type:access_token")
                 .withTokenType("Bearer")
+                .setExpirationInSeconds(100)
                 .build());
 
       default:
diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java
index 02914542b7..c19e3bc6d4 100644
--- a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java
+++ b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java
@@ -18,14 +18,16 @@
  */
 package org.apache.iceberg.aws.s3.signer;
 
+import static org.assertj.core.api.Assertions.assertThat;
+
 import java.nio.file.Paths;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.stream.Collectors;
 import org.apache.iceberg.aws.s3.MinioContainer;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.rest.auth.OAuth2Properties;
-import org.assertj.core.api.Assertions;
 import org.eclipse.jetty.server.Server;
 import org.eclipse.jetty.server.handler.gzip.GzipHandler;
 import org.eclipse.jetty.servlet.ServletContextHandler;
@@ -33,6 +35,7 @@ import org.eclipse.jetty.servlet.ServletHolder;
 import org.jetbrains.annotations.NotNull;
 import org.junit.AfterClass;
 import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -69,6 +72,7 @@ public class TestS3RestSigner {
           AwsBasicCredentials.create("accessKeyId", "secretAccessKey"));
 
   private static Server httpServer;
+  private static ValidatingSigner validatingSigner;
   private S3Client s3;
 
   @Rule public TemporaryFolder temp = new TemporaryFolder();
@@ -77,20 +81,13 @@ public class TestS3RestSigner {
   public MinioContainer minioContainer =
       new MinioContainer(CREDENTIALS_PROVIDER.resolveCredentials());
 
-  @AfterClass
-  public static void afterClass() throws Exception {
-    if (null != httpServer) {
-      httpServer.stop();
-    }
-  }
-
-  @Before
-  public void before() throws Exception {
+  @BeforeClass
+  public static void beforeClass() throws Exception {
     if (null == httpServer) {
       httpServer = initHttpServer();
     }
 
-    ValidatingSigner validatingSigner =
+    validatingSigner =
         new ValidatingSigner(
             ImmutableS3V4RestSignerClient.builder()
                 .properties(
@@ -101,7 +98,34 @@ public class TestS3RestSigner {
                         "catalog:12345"))
                 .build(),
             new CustomAwsS3V4Signer());
+  }
 
+  @AfterClass
+  public static void afterClass() throws Exception {
+    assertThat(validatingSigner.icebergSigner.tokenRefreshExecutor())
+        .isInstanceOf(ScheduledThreadPoolExecutor.class);
+
+    ScheduledThreadPoolExecutor executor =
+        ((ScheduledThreadPoolExecutor) validatingSigner.icebergSigner.tokenRefreshExecutor());
+    // token expiration is set to 100s so there should be exactly one token scheduled for refresh
+    assertThat(executor.getPoolSize()).isEqualTo(1);
+    assertThat(executor.getQueue())
+        .as("should only have a single token scheduled for refresh")
+        .hasSize(1);
+    assertThat(executor.getActiveCount())
+        .as("should not have any token being refreshed")
+        .isEqualTo(0);
+    assertThat(executor.getCompletedTaskCount())
+        .as("should not have any expired token that required a refresh")
+        .isEqualTo(0);
+
+    if (null != httpServer) {
+      httpServer.stop();
+    }
+  }
+
+  @Before
+  public void before() throws Exception {
     s3 =
         S3Client.builder()
             .region(REGION)
@@ -128,7 +152,7 @@ public class TestS3RestSigner {
         CreateMultipartUploadRequest.builder().bucket(BUCKET).key("random/multipart-key").build());
   }
 
-  private Server initHttpServer() throws Exception {
+  private static Server initHttpServer() throws Exception {
     S3SignerServlet servlet = new S3SignerServlet(S3ObjectMapper.mapper());
     ServletContextHandler servletContext =
         new ServletContextHandler(ServletContextHandler.NO_SESSIONS);
@@ -260,10 +284,10 @@ public class TestS3RestSigner {
 
       SdkHttpFullRequest awsResult = signWithAwsSigner(request, signerParams);
 
-      Assertions.assertThat(awsResult.headers().get("Authorization"))
+      assertThat(awsResult.headers().get("Authorization"))
           .isEqualTo(icebergResult.headers().get("Authorization"));
 
-      Assertions.assertThat(awsResult.headers()).isEqualTo(icebergResult.headers());
+      assertThat(awsResult.headers()).isEqualTo(icebergResult.headers());
       return awsResult;
     }