You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by xi...@apache.org on 2024/02/16 08:24:19 UTC

(pinot) branch master updated: cache ssl contexts and reuse them (#12404)

This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 24f48e3963 cache ssl contexts and reuse them (#12404)
24f48e3963 is described below

commit 24f48e39638be537b796a2c22084bdc60d8984a4
Author: Haitao Zhang <ha...@startree.ai>
AuthorDate: Fri Feb 16 00:24:12 2024 -0800

    cache ssl contexts and reuse them (#12404)
    
    * cache ssl contexts and reuse them
    
    * address comments
    
    * update java inline comment
    
    * add a comment
---
 .../org/apache/pinot/common/config/TlsConfig.java  | 79 ++--------------------
 .../pinot/common/utils/grpc/GrpcQueryClient.java   | 28 ++++++--
 .../core/transport/ChannelHandlerFactory.java      | 19 +++++-
 .../pinot/core/transport/grpc/GrpcQueryServer.java | 39 +++++++----
 4 files changed, 70 insertions(+), 95 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/config/TlsConfig.java b/pinot-common/src/main/java/org/apache/pinot/common/config/TlsConfig.java
index fc9344e96f..a5353a8e0c 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/config/TlsConfig.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/config/TlsConfig.java
@@ -20,12 +20,18 @@ package org.apache.pinot.common.config;
 
 import io.netty.handler.ssl.SslProvider;
 import java.security.KeyStore;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
 import org.apache.commons.lang3.StringUtils;
 
 
 /**
  * Container object for TLS/SSL configuration of pinot clients and servers (netty, grizzly, etc.)
  */
+@Getter
+@Setter
+@EqualsAndHashCode
 public class TlsConfig {
   private boolean _clientAuthEnabled;
   private String _keyStoreType = KeyStore.getDefaultType();
@@ -35,6 +41,7 @@ public class TlsConfig {
   private String _trustStorePath;
   private String _trustStorePassword;
   private String _sslProvider = SslProvider.JDK.toString();
+  // If true, the client will not verify the server's certificate
   private boolean _insecure = false;
 
   public TlsConfig() {
@@ -52,79 +59,7 @@ public class TlsConfig {
     _sslProvider = tlsConfig._sslProvider;
   }
 
-  public boolean isClientAuthEnabled() {
-    return _clientAuthEnabled;
-  }
-
-  public void setClientAuthEnabled(boolean clientAuthEnabled) {
-    _clientAuthEnabled = clientAuthEnabled;
-  }
-
-  public String getKeyStoreType() {
-    return _keyStoreType;
-  }
-
-  public void setKeyStoreType(String keyStoreType) {
-    _keyStoreType = keyStoreType;
-  }
-
-  public String getKeyStorePath() {
-    return _keyStorePath;
-  }
-
-  public void setKeyStorePath(String keyStorePath) {
-    _keyStorePath = keyStorePath;
-  }
-
-  public String getKeyStorePassword() {
-    return _keyStorePassword;
-  }
-
-  public void setKeyStorePassword(String keyStorePassword) {
-    _keyStorePassword = keyStorePassword;
-  }
-
-  public String getTrustStoreType() {
-    return _trustStoreType;
-  }
-
-  public void setTrustStoreType(String trustStoreType) {
-    _trustStoreType = trustStoreType;
-  }
-
-  public String getTrustStorePath() {
-    return _trustStorePath;
-  }
-
-  public void setTrustStorePath(String trustStorePath) {
-    _trustStorePath = trustStorePath;
-  }
-
-  public String getTrustStorePassword() {
-    return _trustStorePassword;
-  }
-
-  public void setTrustStorePassword(String trustStorePassword) {
-    _trustStorePassword = trustStorePassword;
-  }
-
-  public String getSslProvider() {
-    return _sslProvider;
-  }
-
-  public void setSslProvider(String sslProvider) {
-    _sslProvider = sslProvider;
-  }
-
   public boolean isCustomized() {
     return StringUtils.isNoneBlank(_keyStorePath) || StringUtils.isNoneBlank(_trustStorePath);
   }
-
-  public boolean isInsecure() {
-    return _insecure;
-  }
-
-  public void setInsecure(boolean insecure) {
-    _insecure = insecure;
-  }
 }
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
index c560a39b3a..94621fa176 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
@@ -22,10 +22,13 @@ import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
 import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
 import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
+import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
 import java.util.Collections;
 import java.util.Iterator;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 import javax.net.ssl.SSLException;
 import nl.altindag.ssl.SSLFactory;
@@ -41,6 +44,10 @@ import org.slf4j.LoggerFactory;
 public class GrpcQueryClient {
   private static final Logger LOGGER = LoggerFactory.getLogger(GrpcQueryClient.class);
   private static final int DEFAULT_CHANNEL_SHUTDOWN_TIMEOUT_SECOND = 10;
+  // The key is the hashCode of the TlsConfig, the value is the SslContext
+  // We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
+  // hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
+  private static final Map<Integer, SslContext> CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();
 
   private final ManagedChannel _managedChannel;
   private final PinotQueryServerGrpc.PinotQueryServerBlockingStub _blockingStub;
@@ -55,8 +62,17 @@ public class GrpcQueryClient {
           ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
               .usePlaintext().build();
     } else {
+      _managedChannel =
+          NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
+              .sslContext(buildSslContext(config.getTlsConfig())).build();
+    }
+    _blockingStub = PinotQueryServerGrpc.newBlockingStub(_managedChannel);
+  }
+
+  private SslContext buildSslContext(TlsConfig tlsConfig) {
+    LOGGER.info("Building gRPC SSL context");
+    SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
       try {
-        TlsConfig tlsConfig = config.getTlsConfig();
         SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
         if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
             && TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
@@ -71,14 +87,12 @@ public class GrpcQueryClient {
         } else {
           sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder);
         }
-        _managedChannel =
-            NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
-                .sslContext(sslContextBuilder.build()).build();
+        return sslContextBuilder.build();
       } catch (SSLException e) {
-        throw new RuntimeException("Failed to create Netty gRPC channel with SSL Context", e);
+        throw new RuntimeException("Failed to build gRPC SSL context", e);
       }
-    }
-    _blockingStub = PinotQueryServerGrpc.newBlockingStub(_managedChannel);
+    });
+    return sslContext;
   }
 
   public Iterator<Server.ServerResponse> submit(Server.ServerRequest request) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/transport/ChannelHandlerFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/transport/ChannelHandlerFactory.java
index d38e53957e..aaa68018f3 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/transport/ChannelHandlerFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/transport/ChannelHandlerFactory.java
@@ -22,6 +22,8 @@ import io.netty.channel.ChannelHandler;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
 import io.netty.handler.codec.LengthFieldPrepender;
+import io.netty.handler.ssl.SslContext;
+import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import org.apache.pinot.common.config.TlsConfig;
 import org.apache.pinot.common.metrics.BrokerMetrics;
@@ -36,8 +38,15 @@ import org.apache.pinot.spi.env.PinotConfiguration;
  * The {@code ChannelHandlerFactory} provides all kinds of Netty ChannelHandlers
  */
 public class ChannelHandlerFactory {
-
   public static final String SSL = "ssl";
+  // The key is the hashCode of the TlsConfig, the value is the SslContext
+  // We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
+  // hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
+  private static final Map<Integer, SslContext> CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();
+  // The key is the hashCode of the TlsConfig, the value is the SslContext
+  // We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
+  // hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
+  private static final Map<Integer, SslContext> SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();
 
   private ChannelHandlerFactory() {
   }
@@ -61,14 +70,18 @@ public class ChannelHandlerFactory {
    * The {@code getClientTlsHandler} return a Client side Tls handler that encrypt and decrypt everything.
    */
   public static ChannelHandler getClientTlsHandler(TlsConfig tlsConfig, SocketChannel ch) {
-    return TlsUtils.buildClientContext(tlsConfig).newHandler(ch.alloc());
+    SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE
+        .computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> TlsUtils.buildClientContext(tlsConfig));
+    return sslContext.newHandler(ch.alloc());
   }
 
   /**
    * The {@code getServerTlsHandler} return a Server side Tls handler that encrypt and decrypt everything.
    */
   public static ChannelHandler getServerTlsHandler(TlsConfig tlsConfig, SocketChannel ch) {
-    return TlsUtils.buildServerContext(tlsConfig).newHandler(ch.alloc());
+    SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(
+        tlsConfig.hashCode(), tlsConfigHashCode -> TlsUtils.buildServerContext(tlsConfig));
+    return sslContext.newHandler(ch.alloc());
   }
 
   /**
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java b/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java
index 44c69eeab6..0b9621e1e1 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java
@@ -29,6 +29,8 @@ import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
 import io.grpc.stub.StreamObserver;
 import java.io.IOException;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import nl.altindag.ssl.SSLFactory;
@@ -56,6 +58,10 @@ import org.slf4j.LoggerFactory;
 // TODO: Plug in QueryScheduler
 public class GrpcQueryServer extends PinotQueryServerGrpc.PinotQueryServerImplBase {
   private static final Logger LOGGER = LoggerFactory.getLogger(GrpcQueryServer.class);
+  // The key is the hashCode of the TlsConfig, the value is the SslContext
+  // We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
+  // hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
+  private static final Map<Integer, SslContext> SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();
 
   private final QueryExecutor _queryExecutor;
   private final ServerMetrics _serverMetrics;
@@ -85,23 +91,30 @@ public class GrpcQueryServer extends PinotQueryServerGrpc.PinotQueryServerImplBa
   }
 
   private SslContext buildGRpcSslContext(TlsConfig tlsConfig)
-      throws Exception {
+      throws IllegalArgumentException {
     LOGGER.info("Building gRPC SSL context");
     if (tlsConfig.getKeyStorePath() == null) {
       throw new IllegalArgumentException("Must provide key store path for secured gRpc server");
     }
-    SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
-    if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
-        && TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
-      TlsUtils.enableAutoRenewalFromFileStoreForSSLFactory(sslFactory, tlsConfig);
-    }
-    SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get())
-        .sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
-    sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
-    if (tlsConfig.isClientAuthEnabled()) {
-      sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
-    }
-    return GrpcSslContexts.configure(sslContextBuilder).build();
+    SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
+      try {
+        SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
+        if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
+            && TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
+          TlsUtils.enableAutoRenewalFromFileStoreForSSLFactory(sslFactory, tlsConfig);
+        }
+        SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get())
+            .sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
+        sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
+        if (tlsConfig.isClientAuthEnabled()) {
+          sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
+        }
+        return GrpcSslContexts.configure(sslContextBuilder).build();
+      } catch (Exception e) {
+        throw new RuntimeException("Failed to build gRPC SSL context", e);
+      }
+    });
+    return sslContext;
   }
 
   public void start() {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org