You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2023/04/27 02:32:27 UTC
[shardingsphere] branch master updated: ShardingSphere-Proxy frontend supports SSL/TLS (#25337)
This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 668525f6a0a ShardingSphere-Proxy frontend supports SSL/TLS (#25337)
668525f6a0a is described below
commit 668525f6a0a4e21904e2249d04a8438cd9637c37
Author: 吴伟杰 <wu...@apache.org>
AuthorDate: Thu Apr 27 10:32:18 2023 +0800
ShardingSphere-Proxy frontend supports SSL/TLS (#25337)
* Proxy frontend supports SSL/TLS
* Add log about key and certificate
* Add SSL props to server.yaml
* Complete ShowDistVariablesExecutorTest
* Remove unused log in MySQLPacketCodecEngine
* Add MySQLSSLRequestHandlerTest
* Add PostgreSQLSSLWillingPacketTest
* Complete MySQLHandshakePacketTest
* Complete MySQLAuthenticationEngineTest
* Register BouncyCastle provider and add SSLUtilsTest
* Add ProxySSLContextTest
* Fix checkstyle
* Add PostgreSQLAuthenticationEngineTest
* Add OpenGaussAuthenticationEngineTest
---
.../mysql/codec/MySQLPacketCodecEngine.java | 2 +-
.../packet/handshake/MySQLHandshakePacket.java | 4 +-
.../packet/handshake/MySQLHandshakePacketTest.java | 16 ++-
...cket.java => PostgreSQLSSLUnwillingPacket.java} | 4 +-
...Packet.java => PostgreSQLSSLWillingPacket.java} | 6 +-
....java => PostgreSQLSSLUnwillingPacketTest.java} | 4 +-
...st.java => PostgreSQLSSLWillingPacketTest.java} | 16 +--
features/encrypt/plugin/sm/pom.xml | 6 +-
.../config/props/ConfigurationPropertyKey.java | 27 +++-
.../client/netty/MySQLNegotiateHandlerTest.java | 2 +-
pom.xml | 1 +
.../queryable/ShowDistVariablesExecutorTest.java | 2 +-
.../org/apache/shardingsphere/proxy/Bootstrap.java | 2 +
.../bootstrap/src/main/resources/conf/server.yaml | 6 +
proxy/frontend/core/pom.xml | 20 +++
.../proxy/frontend/ssl/ProxySSLContext.java | 110 +++++++++++++++++
.../proxy/frontend/ssl/SSLUtils.java | 78 ++++++++++++
.../proxy/frontend/ssl/ProxySSLContextTest.java | 137 +++++++++++++++++++++
.../proxy/frontend/ssl/SSLUtilsTest.java | 70 +++++++++++
.../authentication/MySQLAuthenticationEngine.java | 8 +-
.../frontend/mysql/ssl/MySQLSSLRequestHandler.java | 60 +++++++++
.../MySQLAuthenticationEngineTest.java | 18 ++-
.../mysql/ssl/MySQLSSLRequestHandlerTest.java | 77 ++++++++++++
.../OpenGaussAuthenticationEngine.java | 13 +-
.../OpenGaussAuthenticationEngineTest.java | 33 ++++-
.../PostgreSQLAuthenticationEngine.java | 13 +-
.../PostgreSQLAuthenticationEngineTest.java | 28 ++++-
27 files changed, 718 insertions(+), 45 deletions(-)
diff --git a/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/codec/MySQLPacketCodecEngine.java b/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/codec/MySQLPacketCodecEngine.java
index 905238bfa81..f2d412c1fe3 100644
--- a/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/codec/MySQLPacketCodecEngine.java
+++ b/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/codec/MySQLPacketCodecEngine.java
@@ -20,8 +20,8 @@ package org.apache.shardingsphere.db.protocol.mysql.codec;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelHandlerContext;
-import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
import org.apache.shardingsphere.db.protocol.codec.DatabasePacketCodecEngine;
+import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
import org.apache.shardingsphere.db.protocol.mysql.packet.MySQLPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLErrPacket;
diff --git a/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/handshake/MySQLHandshakePacket.java b/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/handshake/MySQLHandshakePacket.java
index ae178afa084..6c1aa9ef31d 100644
--- a/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/handshake/MySQLHandshakePacket.java
+++ b/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/handshake/MySQLHandshakePacket.java
@@ -52,10 +52,10 @@ public final class MySQLHandshakePacket implements MySQLPacket {
private String authPluginName;
- public MySQLHandshakePacket(final int connectionId, final MySQLAuthenticationPluginData authPluginData) {
+ public MySQLHandshakePacket(final int connectionId, final boolean sslEnabled, final MySQLAuthenticationPluginData authPluginData) {
serverVersion = MySQLServerInfo.getDefaultServerVersion();
this.connectionId = connectionId;
- capabilityFlagsLower = MySQLCapabilityFlag.calculateHandshakeCapabilityFlagsLower();
+ capabilityFlagsLower = MySQLCapabilityFlag.calculateHandshakeCapabilityFlagsLower() | (sslEnabled ? MySQLCapabilityFlag.CLIENT_SSL.getValue() : 0);
characterSet = MySQLServerInfo.DEFAULT_CHARSET.getId();
statusFlag = MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT;
capabilityFlagsUpper = MySQLCapabilityFlag.calculateHandshakeCapabilityFlagsUpper();
diff --git a/db-protocol/mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/handshake/MySQLHandshakePacketTest.java b/db-protocol/mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/handshake/MySQLHandshakePacketTest.java
index c3bd5c8b6df..b21861d819a 100644
--- a/db-protocol/mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/handshake/MySQLHandshakePacketTest.java
+++ b/db-protocol/mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/handshake/MySQLHandshakePacketTest.java
@@ -83,10 +83,22 @@ class MySQLHandshakePacketTest {
assertThat(actual.getAuthPluginName(), is(MySQLAuthenticationMethod.NATIVE.getMethodName()));
}
+ @Test
+ void assertNewWithSSLEnabled() {
+ MySQLHandshakePacket actual = new MySQLHandshakePacket(1, true, new MySQLAuthenticationPluginData());
+ assertThat(actual.getCapabilityFlagsLower() & MySQLCapabilityFlag.CLIENT_SSL.getValue(), is(MySQLCapabilityFlag.CLIENT_SSL.getValue()));
+ }
+
+ @Test
+ void assertNewWithSSLNotEnabled() {
+ MySQLHandshakePacket actual = new MySQLHandshakePacket(1, false, new MySQLAuthenticationPluginData());
+ assertThat(actual.getCapabilityFlagsLower() & MySQLCapabilityFlag.CLIENT_SSL.getValue(), is(0));
+ }
+
@Test
void assertWrite() {
MySQLAuthenticationPluginData authPluginData = new MySQLAuthenticationPluginData(part1, part2);
- new MySQLHandshakePacket(1000, authPluginData).write(payload);
+ new MySQLHandshakePacket(1000, false, authPluginData).write(payload);
verify(payload).writeInt1(MySQLServerInfo.PROTOCOL_VERSION);
verify(payload).writeStringNul(MySQLServerInfo.getDefaultServerVersion());
verify(payload).writeInt4(1000);
@@ -103,7 +115,7 @@ class MySQLHandshakePacketTest {
@Test
void assertWriteWithClientPluginAuth() {
MySQLAuthenticationPluginData authPluginData = new MySQLAuthenticationPluginData(part1, part2);
- MySQLHandshakePacket actual = new MySQLHandshakePacket(1000, authPluginData);
+ MySQLHandshakePacket actual = new MySQLHandshakePacket(1000, false, authPluginData);
actual.setAuthPluginName(MySQLAuthenticationMethod.NATIVE);
actual.write(payload);
verify(payload).writeInt1(MySQLServerInfo.PROTOCOL_VERSION);
diff --git a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacket.java b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLUnwillingPacket.java
similarity index 91%
copy from db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacket.java
copy to db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLUnwillingPacket.java
index 4ba8a68f528..83b40c1db60 100644
--- a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacket.java
+++ b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLUnwillingPacket.java
@@ -21,9 +21,9 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
/**
- * SSL negative packet for PostgreSQL.
+ * SSL unwilling packet for PostgreSQL.
*/
-public final class PostgreSQLSSLNegativePacket implements PostgreSQLPacket {
+public final class PostgreSQLSSLUnwillingPacket implements PostgreSQLPacket {
private static final char STATUS_CODE = 'N';
diff --git a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacket.java b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLWillingPacket.java
similarity index 87%
rename from db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacket.java
rename to db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLWillingPacket.java
index 4ba8a68f528..c543072916a 100644
--- a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacket.java
+++ b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLWillingPacket.java
@@ -21,11 +21,11 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
/**
- * SSL negative packet for PostgreSQL.
+ * SSL willing packet for PostgreSQL.
*/
-public final class PostgreSQLSSLNegativePacket implements PostgreSQLPacket {
+public final class PostgreSQLSSLWillingPacket implements PostgreSQLPacket {
- private static final char STATUS_CODE = 'N';
+ private static final char STATUS_CODE = 'S';
@Override
public void write(final PostgreSQLPacketPayload payload) {
diff --git a/db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacketTest.java b/db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLUnwillingPacketTest.java
similarity index 92%
copy from db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacketTest.java
copy to db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLUnwillingPacketTest.java
index 5e15c5d735a..3795e250e2d 100644
--- a/db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacketTest.java
+++ b/db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLUnwillingPacketTest.java
@@ -27,13 +27,13 @@ import java.nio.charset.StandardCharsets;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
-class PostgreSQLSSLNegativePacketTest {
+class PostgreSQLSSLUnwillingPacketTest {
@Test
void assertReadWrite() {
ByteBuf byteBuf = ByteBufTestUtils.createByteBuf(1);
PostgreSQLPacketPayload payload = new PostgreSQLPacketPayload(byteBuf, StandardCharsets.UTF_8);
- PostgreSQLSSLNegativePacket packet = new PostgreSQLSSLNegativePacket();
+ PostgreSQLSSLUnwillingPacket packet = new PostgreSQLSSLUnwillingPacket();
packet.write(payload);
assertThat(byteBuf.writerIndex(), is(1));
assertThat(payload.readInt1(), is((int) 'N'));
diff --git a/db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacketTest.java b/db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLWillingPacketTest.java
similarity index 67%
rename from db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacketTest.java
rename to db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLWillingPacketTest.java
index 5e15c5d735a..778e56bb649 100644
--- a/db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLNegativePacketTest.java
+++ b/db-protocol/postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLSSLWillingPacketTest.java
@@ -17,8 +17,7 @@
package org.apache.shardingsphere.db.protocol.postgresql.packet.handshake;
-import io.netty.buffer.ByteBuf;
-import org.apache.shardingsphere.db.protocol.postgresql.packet.ByteBufTestUtils;
+import io.netty.buffer.Unpooled;
import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import org.junit.jupiter.api.Test;
@@ -27,15 +26,12 @@ import java.nio.charset.StandardCharsets;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
-class PostgreSQLSSLNegativePacketTest {
+class PostgreSQLSSLWillingPacketTest {
@Test
- void assertReadWrite() {
- ByteBuf byteBuf = ByteBufTestUtils.createByteBuf(1);
- PostgreSQLPacketPayload payload = new PostgreSQLPacketPayload(byteBuf, StandardCharsets.UTF_8);
- PostgreSQLSSLNegativePacket packet = new PostgreSQLSSLNegativePacket();
- packet.write(payload);
- assertThat(byteBuf.writerIndex(), is(1));
- assertThat(payload.readInt1(), is((int) 'N'));
+ void assertWrite() {
+ byte[] actual = new byte[1];
+ new PostgreSQLSSLWillingPacket().write(new PostgreSQLPacketPayload(Unpooled.wrappedBuffer(actual).writerIndex(0), StandardCharsets.UTF_8));
+ assertThat(actual[0], is((byte) 'S'));
}
}
diff --git a/features/encrypt/plugin/sm/pom.xml b/features/encrypt/plugin/sm/pom.xml
index 7d1486738e5..5f3d4e11b66 100644
--- a/features/encrypt/plugin/sm/pom.xml
+++ b/features/encrypt/plugin/sm/pom.xml
@@ -27,10 +27,6 @@
<artifactId>shardingsphere-encrypt-sm</artifactId>
<name>${project.artifactId}</name>
- <properties>
- <bcprov-jdk15on.version>1.70</bcprov-jdk15on.version>
- </properties>
-
<dependencies>
<dependency>
<groupId>org.apache.shardingsphere</groupId>
@@ -54,7 +50,7 @@
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
- <version>${bcprov-jdk15on.version}</version>
+ <version>${bouncycastle.version}</version>
</dependency>
</dependencies>
</project>
diff --git a/infra/common/src/main/java/org/apache/shardingsphere/infra/config/props/ConfigurationPropertyKey.java b/infra/common/src/main/java/org/apache/shardingsphere/infra/config/props/ConfigurationPropertyKey.java
index a30cc91dd5c..7e25bd4ea2d 100644
--- a/infra/common/src/main/java/org/apache/shardingsphere/infra/config/props/ConfigurationPropertyKey.java
+++ b/infra/common/src/main/java/org/apache/shardingsphere/infra/config/props/ConfigurationPropertyKey.java
@@ -116,7 +116,32 @@ public enum ConfigurationPropertyKey implements TypedPropertyKey {
/**
* CDC server port.
*/
- CDC_SERVER_PORT("cdc-server-port", "33071", int.class, true);
+ CDC_SERVER_PORT("cdc-server-port", "33071", int.class, true),
+
+ /**
+ * Proxy frontend SSL enabled.
+ */
+ PROXY_FRONTEND_SSL_ENABLED("proxy-frontend-ssl-enabled", String.valueOf(Boolean.FALSE), boolean.class, true),
+
+ /**
+ * Proxy frontend SSL certificate file.
+ */
+ PROXY_FRONTEND_SSL_CERT_FILE("proxy-frontend-ssl-cert-file", "", String.class, true),
+
+ /**
+ * Proxy frontend SSL private key file.
+ */
+ PROXY_FRONTEND_SSL_KEY_FILE("proxy-frontend-ssl-key-file", "", String.class, true),
+
+ /**
+ * Proxy frontend SSL protocol version.
+ */
+ PROXY_FRONTEND_SSL_VERSION("proxy-frontend-ssl-version", "TLSv1.2,TLSv1.3", String.class, true),
+
+ /**
+ * Proxy frontend SSL cipher.
+ */
+ PROXY_FRONTEND_SSL_CIPHER("proxy-frontend-ssl-cipher", "", String.class, true);
private final String key;
diff --git a/kernel/data-pipeline/dialect/mysql/src/test/java/org/apache/shardingsphere/data/pipeline/mysql/ingest/client/netty/MySQLNegotiateHandlerTest.java b/kernel/data-pipeline/dialect/mysql/src/test/java/org/apache/shardingsphere/data/pipeline/mysql/ingest/client/netty/MySQLNegotiateHandlerTest.java
index 186aef8c3df..e37ed9672f6 100644
--- a/kernel/data-pipeline/dialect/mysql/src/test/java/org/apache/shardingsphere/data/pipeline/mysql/ingest/client/netty/MySQLNegotiateHandlerTest.java
+++ b/kernel/data-pipeline/dialect/mysql/src/test/java/org/apache/shardingsphere/data/pipeline/mysql/ingest/client/netty/MySQLNegotiateHandlerTest.java
@@ -77,7 +77,7 @@ class MySQLNegotiateHandlerTest {
@Test
void assertChannelReadHandshakeInitPacket() throws ReflectiveOperationException {
- MySQLHandshakePacket handshakePacket = new MySQLHandshakePacket(0, new MySQLAuthenticationPluginData(new byte[8], new byte[12]));
+ MySQLHandshakePacket handshakePacket = new MySQLHandshakePacket(0, false, new MySQLAuthenticationPluginData(new byte[8], new byte[12]));
handshakePacket.setAuthPluginName(MySQLAuthenticationMethod.NATIVE);
mysqlNegotiateHandler.channelRead(channelHandlerContext, handshakePacket);
verify(channel).writeAndFlush(ArgumentMatchers.any(MySQLHandshakeResponse41Packet.class));
diff --git a/pom.xml b/pom.xml
index a742967cf6b..306e16ee45b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -86,6 +86,7 @@
<calcite.version>1.32.0</calcite.version>
<netty.version>4.1.90.Final</netty.version>
+ <bouncycastle.version>1.70</bouncycastle.version>
<javax.transaction.version>1.1</javax.transaction.version>
diff --git a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/distsql/ral/queryable/ShowDistVariablesExecutorTest.java b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/distsql/ral/queryable/ShowDistVariablesExecutorTest.java
index 38a108d0520..13196355585 100644
--- a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/distsql/ral/queryable/ShowDistVariablesExecutorTest.java
+++ b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/distsql/ral/queryable/ShowDistVariablesExecutorTest.java
@@ -63,7 +63,7 @@ class ShowDistVariablesExecutorTest {
when(metaData.getGlobalRuleMetaData()).thenReturn(new ShardingSphereRuleMetaData(Collections.singleton(new LoggingRule(new DefaultLoggingRuleConfigurationBuilder().build()))));
ShowDistVariablesExecutor executor = new ShowDistVariablesExecutor();
Collection<LocalDataQueryResultRow> actual = executor.getRows(metaData, connectionSession, mock(ShowDistVariablesStatement.class));
- assertThat(actual.size(), is(21));
+ assertThat(actual.size(), is(26));
LocalDataQueryResultRow row = actual.iterator().next();
assertThat(row.getCell(1), is("agent_plugins_enabled"));
assertThat(row.getCell(2), is("true"));
diff --git a/proxy/bootstrap/src/main/java/org/apache/shardingsphere/proxy/Bootstrap.java b/proxy/bootstrap/src/main/java/org/apache/shardingsphere/proxy/Bootstrap.java
index eb4edabe538..38b67e1ef0b 100644
--- a/proxy/bootstrap/src/main/java/org/apache/shardingsphere/proxy/Bootstrap.java
+++ b/proxy/bootstrap/src/main/java/org/apache/shardingsphere/proxy/Bootstrap.java
@@ -26,6 +26,7 @@ import org.apache.shardingsphere.proxy.backend.config.ProxyConfigurationLoader;
import org.apache.shardingsphere.proxy.backend.config.YamlProxyConfiguration;
import org.apache.shardingsphere.proxy.frontend.CDCServer;
import org.apache.shardingsphere.proxy.frontend.ShardingSphereProxy;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
import org.apache.shardingsphere.proxy.initializer.BootstrapInitializer;
import java.io.IOException;
@@ -54,6 +55,7 @@ public final class Bootstrap {
new BootstrapInitializer().init(yamlConfig, port, bootstrapArgs.isForce());
Optional.ofNullable((Integer) yamlConfig.getServerConfiguration().getProps().get(ConfigurationPropertyKey.CDC_SERVER_PORT.getKey()))
.ifPresent(cdcPort -> new CDCServer(addresses, cdcPort).start());
+ ProxySSLContext.init();
ShardingSphereProxy shardingSphereProxy = new ShardingSphereProxy();
bootstrapArgs.getSocketPath().ifPresent(shardingSphereProxy::start);
shardingSphereProxy.start(port, addresses);
diff --git a/proxy/bootstrap/src/main/resources/conf/server.yaml b/proxy/bootstrap/src/main/resources/conf/server.yaml
index 5cb0a59792b..0d1e1b2115f 100644
--- a/proxy/bootstrap/src/main/resources/conf/server.yaml
+++ b/proxy/bootstrap/src/main/resources/conf/server.yaml
@@ -82,3 +82,9 @@
# proxy-default-port: 3307 # Proxy default port.
# proxy-netty-backlog: 1024 # Proxy netty backlog.
# cdc-server-port: 33071 # CDC server port
+# proxy-frontend-ssl-enabled: false
+# # When the certificate and private key files are not set, Proxy will generate a temporary RSA key pair and a self-signed X.509 certificate.
+# proxy-frontend-ssl-cert-file: ''
+# proxy-frontend-ssl-key-file: ''
+# proxy-frontend-ssl-cipher: ''
+# proxy-frontend-ssl-version: TLSv1.2,TLSv1.3
diff --git a/proxy/frontend/core/pom.xml b/proxy/frontend/core/pom.xml
index 0c3c2512e56..a59c14f008d 100644
--- a/proxy/frontend/core/pom.xml
+++ b/proxy/frontend/core/pom.xml
@@ -27,6 +27,10 @@
<artifactId>shardingsphere-proxy-frontend-core</artifactId>
<name>${project.artifactId}</name>
+ <properties>
+ <netty.tcnative.version>2.0.59.Final</netty.tcnative.version>
+ </properties>
+
<dependencies>
<dependency>
<groupId>org.apache.shardingsphere</groupId>
@@ -90,5 +94,21 @@
<artifactId>netty-transport-native-epoll</artifactId>
<classifier>linux-aarch_64</classifier>
</dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-tcnative-boringssl-static</artifactId>
+ <version>${netty.tcnative.version}</version>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.bouncycastle</groupId>
+ <artifactId>bcpkix-jdk15on</artifactId>
+ <version>${bouncycastle.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.bouncycastle</groupId>
+ <artifactId>bctls-jdk15on</artifactId>
+ <version>${bouncycastle.version}</version>
+ </dependency>
</dependencies>
</project>
diff --git a/proxy/frontend/core/src/main/java/org/apache/shardingsphere/proxy/frontend/ssl/ProxySSLContext.java b/proxy/frontend/core/src/main/java/org/apache/shardingsphere/proxy/frontend/ssl/ProxySSLContext.java
new file mode 100644
index 00000000000..65bb89f3dbd
--- /dev/null
+++ b/proxy/frontend/core/src/main/java/org/apache/shardingsphere/proxy/frontend/ssl/ProxySSLContext.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.proxy.frontend.ssl;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
+import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
+
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLException;
+import java.nio.file.Paths;
+import java.security.KeyPair;
+import java.security.cert.X509Certificate;
+import java.util.Arrays;
+
+/**
+ * Proxy SSL context.
+ *
+ * <p>Singleton holder of the Netty {@link SslContext} used by the Proxy frontend.
+ * The context stays unset until {@link #init()} builds it, which only happens when
+ * {@code proxy-frontend-ssl-enabled} is true.</p>
+ */
+@Slf4j
+public final class ProxySSLContext {
+    
+    private static final ProxySSLContext INSTANCE = new ProxySSLContext();
+    
+    private SslContext sslContext;
+    
+    /**
+     * Init SSL context.
+     *
+     * <p>Does nothing except log when SSL is disabled. Otherwise the server certificate comes either from the
+     * configured certificate / private key files, or from a freshly generated RSA key pair with a self-signed
+     * certificate when neither file is configured. Protocol versions and, if present, cipher suites are read
+     * from comma-separated properties; whitespace around the commas is ignored.</p>
+     *
+     * @throws SSLException SSL exception
+     * @throws IllegalArgumentException if only one of the certificate file and the private key file is configured
+     */
+    public static void init() throws SSLException {
+        if (!ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getProps().<Boolean>getValue(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_ENABLED)) {
+            log.info("Proxy frontend SSL/TLS is not enabled.");
+            return;
+        }
+        String certFile = getTrimmedProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_CERT_FILE);
+        String keyFile = getTrimmedProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_KEY_FILE);
+        SslContextBuilder sslContextBuilder;
+        if (isUserProvidedCertificate(certFile, keyFile)) {
+            sslContextBuilder = SslContextBuilder.forServer(Paths.get(certFile).toFile(), Paths.get(keyFile).toFile());
+            log.info("Using X.509 certificate chain file [{}] and private key file [{}]", certFile, keyFile);
+        } else {
+            KeyPair keyPair = SSLUtils.generateRSAKeyPair();
+            X509Certificate x509Certificate = SSLUtils.generateSelfSignedX509Certificate(keyPair);
+            sslContextBuilder = SslContextBuilder.forServer(keyPair.getPrivate(), x509Certificate);
+            log.warn("RSA key pair and CA certificate are generated by ShardingSphere-Proxy and self-signed.");
+        }
+        String versions = getTrimmedProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_VERSION);
+        sslContextBuilder.protocols(splitCommaSeparated(versions));
+        String ciphers = getTrimmedProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_CIPHER);
+        // An empty cipher property leaves the builder's cipher selection untouched.
+        if (!ciphers.isEmpty()) {
+            sslContextBuilder.ciphers(Arrays.asList(splitCommaSeparated(ciphers)));
+        }
+        INSTANCE.sslContext = sslContextBuilder.build();
+        log.info("Proxy frontend SSL/TLS is enabled. Supported protocols: {}", versions);
+    }
+    
+    /**
+     * Read a string property from the Proxy configuration and strip surrounding whitespace.
+     */
+    private static String getTrimmedProperty(final ConfigurationPropertyKey key) {
+        return ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getProps().<String>getValue(key).trim();
+    }
+    
+    /**
+     * Split an already trimmed comma-separated value, dropping whitespace around each comma
+     * so that a value such as {@code "TLSv1.2, TLSv1.3"} does not yield a token with a leading space.
+     */
+    private static String[] splitCommaSeparated(final String value) {
+        return value.split("\\s*,\\s*");
+    }
+    
+    /**
+     * Check that the certificate file and the private key file are configured consistently.
+     *
+     * @return true when both files are configured, false when neither is
+     */
+    private static boolean isUserProvidedCertificate(final String certFile, final String keyFile) {
+        // Parameter order matches the call site; the two values must be both empty or both non-empty.
+        Preconditions.checkArgument(keyFile.isEmpty() == certFile.isEmpty(), "%s and %s must be either both empty or both non-empty",
+                ConfigurationPropertyKey.PROXY_FRONTEND_SSL_KEY_FILE.getKey(), ConfigurationPropertyKey.PROXY_FRONTEND_SSL_CERT_FILE.getKey());
+        return !keyFile.isEmpty();
+    }
+    
+    /**
+     * Get instance of {@link ProxySSLContext}.
+     *
+     * @return instance of {@link ProxySSLContext}
+     */
+    public static ProxySSLContext getInstance() {
+        return INSTANCE;
+    }
+    
+    /**
+     * Is SSL enabled.
+     *
+     * @return true once {@link #init()} has built an SSL context, false otherwise
+     */
+    public boolean isSSLEnabled() {
+        return null != sslContext;
+    }
+    
+    /**
+     * Create a new {@link SSLEngine}.
+     *
+     * <p>Dereferences the SSL context directly, so callers should check {@link #isSSLEnabled()} first.</p>
+     *
+     * @param allocator allocator
+     * @return a new {@link SSLEngine}
+     */
+    public SSLEngine newSSLEngine(final ByteBufAllocator allocator) {
+        return sslContext.newEngine(allocator);
+    }
+}
diff --git a/proxy/frontend/core/src/main/java/org/apache/shardingsphere/proxy/frontend/ssl/SSLUtils.java b/proxy/frontend/core/src/main/java/org/apache/shardingsphere/proxy/frontend/ssl/SSLUtils.java
new file mode 100644
index 00000000000..8bd9db1e359
--- /dev/null
+++ b/proxy/frontend/core/src/main/java/org/apache/shardingsphere/proxy/frontend/ssl/SSLUtils.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.proxy.frontend.ssl;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import lombok.SneakyThrows;
+import org.bouncycastle.asn1.x500.X500Name;
+import org.bouncycastle.asn1.x500.X500NameBuilder;
+import org.bouncycastle.asn1.x500.style.BCStyle;
+import org.bouncycastle.cert.X509v3CertificateBuilder;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
+import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+
+import java.math.BigInteger;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.NoSuchAlgorithmException;
+import java.security.NoSuchProviderException;
+import java.security.SecureRandom;
+import java.security.Security;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.Calendar;
+import java.util.Date;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * SSL utils.
+ *
+ * <p>Package-private helpers that create an RSA key pair and a self-signed X.509 certificate for it.</p>
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+public final class SSLUtils {
+    
+    static {
+        // Register BouncyCastle so the provider requested by name in generateRSAKeyPair() can be resolved.
+        Security.addProvider(new BouncyCastleProvider());
+    }
+    
+    /**
+     * Generate a 4096-bit RSA key pair using the BouncyCastle provider.
+     *
+     * @return generated key pair
+     */
+    @SneakyThrows({NoSuchProviderException.class, NoSuchAlgorithmException.class})
+    static KeyPair generateRSAKeyPair() {
+        KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA", BouncyCastleProvider.PROVIDER_NAME);
+        generator.initialize(4096, new SecureRandom());
+        return generator.generateKeyPair();
+    }
+    
+    /**
+     * Generate a self-signed X.509 certificate for the given key pair.
+     *
+     * <p>The certificate is valid from one day before now for 100 years, uses the current time in milliseconds
+     * as its serial number, and carries the same distinguished name (all attributes empty) as subject and issuer.</p>
+     *
+     * @param keyPair key pair whose public key is certified and whose private key signs the certificate
+     * @return self-signed certificate
+     */
+    @SneakyThrows({OperatorCreationException.class, CertificateException.class})
+    static X509Certificate generateSelfSignedX509Certificate(final KeyPair keyPair) {
+        long currentMillis = System.currentTimeMillis();
+        // Back-date the start of validity by one day.
+        Date notBefore = new Date(currentMillis - TimeUnit.DAYS.toMillis(1));
+        Calendar expiry = Calendar.getInstance();
+        expiry.setTime(notBefore);
+        expiry.add(Calendar.YEAR, 100);
+        Date notAfter = expiry.getTime();
+        BigInteger serialNumber = BigInteger.valueOf(currentMillis);
+        X500Name subjectAndIssuer = new X500NameBuilder(BCStyle.INSTANCE)
+                .addRDN(BCStyle.CN, "")
+                .addRDN(BCStyle.OU, "")
+                .addRDN(BCStyle.O, "")
+                .addRDN(BCStyle.L, "")
+                .addRDN(BCStyle.ST, "")
+                .addRDN(BCStyle.C, "")
+                .addRDN(BCStyle.E, "")
+                .build();
+        ContentSigner signer = new JcaContentSignerBuilder("SHA256WithRSA").build(keyPair.getPrivate());
+        X509v3CertificateBuilder certificateBuilder = new JcaX509v3CertificateBuilder(subjectAndIssuer, serialNumber, notBefore, notAfter, subjectAndIssuer, keyPair.getPublic());
+        return new JcaX509CertificateConverter().getCertificate(certificateBuilder.build(signer));
+    }
+}
diff --git a/proxy/frontend/core/src/test/java/org/apache/shardingsphere/proxy/frontend/ssl/ProxySSLContextTest.java b/proxy/frontend/core/src/test/java/org/apache/shardingsphere/proxy/frontend/ssl/ProxySSLContextTest.java
new file mode 100644
index 00000000000..579f58f134a
--- /dev/null
+++ b/proxy/frontend/core/src/test/java/org/apache/shardingsphere/proxy/frontend/ssl/ProxySSLContextTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.proxy.frontend.ssl;
+
+import io.netty.buffer.UnpooledByteBufAllocator;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import lombok.SneakyThrows;
+import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
+import org.apache.shardingsphere.mode.manager.ContextManager;
+import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
+import org.apache.shardingsphere.test.mock.AutoMockExtension;
+import org.apache.shardingsphere.test.mock.StaticMockSettings;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.MockedStatic;
+import org.mockito.internal.configuration.plugins.Plugins;
+
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLException;
+import java.io.File;
+import java.security.PrivateKey;
+import java.security.cert.X509Certificate;
+import java.util.Arrays;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Test cases for initialization of {@link ProxySSLContext}.
+ *
+ * <p>The private {@code sslContext} field of the singleton is reset to {@code null} before and after every test.</p>
+ */
+@ExtendWith(AutoMockExtension.class)
+@StaticMockSettings(ProxyContext.class)
+class ProxySSLContextTest {
+
+    @BeforeEach
+    void setup() throws NoSuchFieldException, IllegalAccessException {
+        clearSslContext();
+        when(ProxyContext.getInstance().getContextManager()).thenReturn(mock(ContextManager.class, RETURNS_DEEP_STUBS));
+    }
+
+    @AfterEach
+    void tearDown() throws NoSuchFieldException, IllegalAccessException {
+        clearSslContext();
+    }
+
+    private void clearSslContext() throws NoSuchFieldException, IllegalAccessException {
+        Plugins.getMemberAccessor().set(ProxySSLContext.class.getDeclaredField("sslContext"), ProxySSLContext.getInstance(), null);
+    }
+
+    @Test
+    void assertInitWithSSLNotEnabled() throws SSLException {
+        mockProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_ENABLED, false);
+        ProxySSLContext.init();
+        assertNull(getSslContext());
+    }
+
+    @Test
+    void assertInitWithIllegalConfig() {
+        // Only the key file is configured; init is expected to reject this combination.
+        mockSSLEnabled("", "key");
+        assertThrows(IllegalArgumentException.class, ProxySSLContext::init);
+    }
+
+    @Test
+    void assertInitWithUserProvidedCertificate() throws SSLException {
+        mockSSLEnabled("cert", "key");
+        mockVersionsAndCiphers();
+        SSLEngine expectedSSLEngine = mock(SSLEngine.class);
+        SslContext expectedSslContext = mockSslContext(expectedSSLEngine);
+        SslContextBuilder builder = mockSslContextBuilder(expectedSslContext);
+        try (MockedStatic<SslContextBuilder> mockedStatic = mockStatic(SslContextBuilder.class)) {
+            mockedStatic.when(() -> SslContextBuilder.forServer(any(File.class), any(File.class))).thenReturn(builder);
+            ProxySSLContext.init();
+        }
+        verify(builder).protocols("TLSv1.2", "TLSv1.3");
+        verify(builder).ciphers(Arrays.asList("CIPHER1", "CIPHER2"));
+        assertSSLEnabled(expectedSslContext, expectedSSLEngine);
+    }
+
+    @Test
+    void assertInitWithGeneratedSelfSignedCertificate() throws SSLException {
+        // Neither certificate nor key file is configured, so the key/certificate overload of forServer is stubbed.
+        mockSSLEnabled("", "");
+        mockVersionsAndCiphers();
+        SSLEngine expectedSSLEngine = mock(SSLEngine.class);
+        SslContext expectedSslContext = mockSslContext(expectedSSLEngine);
+        SslContextBuilder builder = mockSslContextBuilder(expectedSslContext);
+        try (MockedStatic<SslContextBuilder> mockedStatic = mockStatic(SslContextBuilder.class)) {
+            mockedStatic.when(() -> SslContextBuilder.forServer(any(PrivateKey.class), any(X509Certificate.class))).thenReturn(builder);
+            ProxySSLContext.init();
+        }
+        assertSSLEnabled(expectedSslContext, expectedSSLEngine);
+    }
+
+    private void mockSSLEnabled(final String certFile, final String keyFile) {
+        mockProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_ENABLED, true);
+        mockProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_CERT_FILE, certFile);
+        mockProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_KEY_FILE, keyFile);
+    }
+
+    private void mockVersionsAndCiphers() {
+        mockProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_VERSION, "TLSv1.2,TLSv1.3");
+        mockProperty(ConfigurationPropertyKey.PROXY_FRONTEND_SSL_CIPHER, "CIPHER1,CIPHER2");
+    }
+
+    private <T> void mockProperty(final ConfigurationPropertyKey key, final T value) {
+        when(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getProps().<T>getValue(key)).thenReturn(value);
+    }
+
+    private SslContext mockSslContext(final SSLEngine sslEngine) {
+        SslContext result = mock(SslContext.class);
+        when(result.newEngine(UnpooledByteBufAllocator.DEFAULT)).thenReturn(sslEngine);
+        return result;
+    }
+
+    private SslContextBuilder mockSslContextBuilder(final SslContext sslContext) throws SSLException {
+        SslContextBuilder result = mock(SslContextBuilder.class);
+        when(result.build()).thenReturn(sslContext);
+        return result;
+    }
+
+    private void assertSSLEnabled(final SslContext expectedSslContext, final SSLEngine expectedSSLEngine) {
+        assertThat(getSslContext(), is(expectedSslContext));
+        assertTrue(ProxySSLContext.getInstance().isSSLEnabled());
+        assertThat(ProxySSLContext.getInstance().newSSLEngine(UnpooledByteBufAllocator.DEFAULT), is(expectedSSLEngine));
+    }
+
+    @SneakyThrows({NoSuchFieldException.class, IllegalAccessException.class})
+    private SslContext getSslContext() {
+        return (SslContext) Plugins.getMemberAccessor().get(ProxySSLContext.class.getDeclaredField("sslContext"), ProxySSLContext.getInstance());
+    }
+}
diff --git a/proxy/frontend/core/src/test/java/org/apache/shardingsphere/proxy/frontend/ssl/SSLUtilsTest.java b/proxy/frontend/core/src/test/java/org/apache/shardingsphere/proxy/frontend/ssl/SSLUtilsTest.java
new file mode 100644
index 00000000000..b2bf071db98
--- /dev/null
+++ b/proxy/frontend/core/src/test/java/org/apache/shardingsphere/proxy/frontend/ssl/SSLUtilsTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.proxy.frontend.ssl;
+
+import org.junit.jupiter.api.Test;
+
+import java.security.InvalidKeyException;
+import java.security.KeyPair;
+import java.security.NoSuchAlgorithmException;
+import java.security.NoSuchProviderException;
+import java.security.SignatureException;
+import java.security.cert.CertificateException;
+import java.security.cert.CertificateExpiredException;
+import java.security.cert.CertificateNotYetValidException;
+import java.security.cert.X509Certificate;
+import java.util.Calendar;
+import java.util.Date;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.fail;
+
+/**
+ * Test cases for the key pair and self-signed certificate produced by {@link SSLUtils}.
+ */
+class SSLUtilsTest {
+
+    @Test
+    void assertGenerateKeyPair() {
+        KeyPair actual = SSLUtils.generateRSAKeyPair();
+        String privateKeyAlgorithm = actual.getPrivate().getAlgorithm();
+        assertThat(privateKeyAlgorithm, is("RSA"));
+        String privateKeyFormat = actual.getPrivate().getFormat();
+        assertThat(privateKeyFormat, is("PKCS#8"));
+        String publicKeyAlgorithm = actual.getPublic().getAlgorithm();
+        assertThat(publicKeyAlgorithm, is("RSA"));
+        String publicKeyFormat = actual.getPublic().getFormat();
+        assertThat(publicKeyFormat, is("X.509"));
+    }
+
+    @Test
+    void assertGenerateSelfSignedX509Certificate() {
+        KeyPair keyPair = SSLUtils.generateRSAKeyPair();
+        X509Certificate actual = SSLUtils.generateSelfSignedX509Certificate(keyPair);
+        // The certificate has to be valid both now and 99 years from now.
+        Calendar farFuture = Calendar.getInstance();
+        farFuture.setTime(new Date());
+        farFuture.add(Calendar.YEAR, 99);
+        assertValidOn(actual, new Date(), farFuture.getTime());
+        assertSignedBy(actual, keyPair);
+        // Verification against an unrelated public key has to fail.
+        assertThrows(SignatureException.class, () -> actual.verify(SSLUtils.generateRSAKeyPair().getPublic()));
+    }
+
+    private void assertValidOn(final X509Certificate certificate, final Date... dates) {
+        try {
+            for (Date each : dates) {
+                certificate.checkValidity(each);
+            }
+        } catch (final CertificateExpiredException | CertificateNotYetValidException ex) {
+            fail(ex);
+        }
+    }
+
+    private void assertSignedBy(final X509Certificate certificate, final KeyPair keyPair) {
+        try {
+            certificate.verify(keyPair.getPublic());
+        } catch (final CertificateException | NoSuchAlgorithmException | InvalidKeyException | NoSuchProviderException | SignatureException ex) {
+            fail(ex);
+        }
+    }
+}
diff --git a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.java b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.java
index 08bf7de52cf..398589a2d70 100644
--- a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.java
+++ b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.java
@@ -53,6 +53,8 @@ import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticatorFact
import org.apache.shardingsphere.proxy.frontend.connection.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.mysql.authentication.authenticator.MySQLAuthenticatorType;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIDGenerator;
+import org.apache.shardingsphere.proxy.frontend.mysql.ssl.MySQLSSLRequestHandler;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
@@ -76,7 +78,11 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
public int handshake(final ChannelHandlerContext context) {
int result = ConnectionIdGenerator.getInstance().nextId();
connectionPhase = MySQLConnectionPhase.AUTH_PHASE_FAST_PATH;
- context.writeAndFlush(new MySQLHandshakePacket(result, authPluginData));
+ boolean sslEnabled = ProxySSLContext.getInstance().isSSLEnabled();
+ if (sslEnabled) {
+ context.pipeline().addFirst(MySQLSSLRequestHandler.class.getSimpleName(), new MySQLSSLRequestHandler()); // placed at the head of the pipeline, named by its simple class name, before the handshake packet is written
+ }
+ context.writeAndFlush(new MySQLHandshakePacket(result, sslEnabled, authPluginData)); // the handshake packet is constructed with the SSL-enabled flag
MySQLStatementIDGenerator.getInstance().registerConnection(result);
return result;
}
diff --git a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/ssl/MySQLSSLRequestHandler.java b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/ssl/MySQLSSLRequestHandler.java
new file mode 100644
index 00000000000..f6f2192d50c
--- /dev/null
+++ b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/ssl/MySQLSSLRequestHandler.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.proxy.frontend.mysql.ssl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.ssl.SslHandler;
+import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLCapabilityFlag;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
+
+import java.util.List;
+
+/**
+ * MySQL SSL request handler.
+ *
+ * <p>Inspects the first complete packet received from the client. If that packet has a 32-byte payload
+ * and the CLIENT_SSL capability bit set, an {@link SslHandler} is added right after this handler and the
+ * packet is consumed. In either case this handler then removes itself from the pipeline.</p>
+ */
+public final class MySQLSSLRequestHandler extends ByteToMessageDecoder {
+
+    private static final int HEADER_LENGTH = 4;
+
+    private static final int SSL_REQUEST_LENGTH = 32;
+
+    public MySQLSSLRequestHandler() {
+        setSingleDecode(true);
+    }
+
+    @Override
+    protected void decode(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out) {
+        if (!isWholePacketReadable(in)) {
+            return;
+        }
+        if (isSSLRequest(in)) {
+            installSslHandler(context);
+            // Consume the SSL request so that it is not passed on with the remaining bytes.
+            in.skipBytes(HEADER_LENGTH + SSL_REQUEST_LENGTH);
+        }
+        context.pipeline().remove(this);
+    }
+
+    /**
+     * Check whether the header and the payload length announced in its first three bytes are fully readable.
+     *
+     * @param in inbound buffer
+     * @return true if a whole packet is readable
+     */
+    private boolean isWholePacketReadable(final ByteBuf in) {
+        int readableBytes = in.readableBytes();
+        if (readableBytes < HEADER_LENGTH) {
+            return false;
+        }
+        return readableBytes >= HEADER_LENGTH + in.getUnsignedMediumLE(in.readerIndex());
+    }
+
+    /**
+     * Add an {@link SslHandler}, named by its simple class name, directly after this handler.
+     *
+     * @param context channel handler context
+     */
+    private void installSslHandler(final ChannelHandlerContext context) {
+        SslHandler sslHandler = new SslHandler(ProxySSLContext.getInstance().newSSLEngine(context.alloc()));
+        context.pipeline().addAfter(MySQLSSLRequestHandler.class.getSimpleName(), SslHandler.class.getSimpleName(), sslHandler);
+    }
+
+    /**
+     * Check whether the packet at the reader index has a 32-byte payload with CLIENT_SSL set in the
+     * two bytes that follow the header.
+     *
+     * @param in inbound buffer
+     * @return true if the packet is an SSL request
+     */
+    private boolean isSSLRequest(final ByteBuf in) {
+        int packetStart = in.readerIndex();
+        if (SSL_REQUEST_LENGTH != in.getUnsignedMediumLE(packetStart)) {
+            return false;
+        }
+        int clientCapabilitiesFlagOffset = packetStart + HEADER_LENGTH;
+        return MySQLCapabilityFlag.CLIENT_SSL.getValue() == (MySQLCapabilityFlag.CLIENT_SSL.getValue() & in.getUnsignedShortLE(clientCapabilitiesFlagOffset));
+    }
+}
diff --git a/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java b/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
index 89ec69ac598..aa303b60efb 100644
--- a/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
+++ b/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
@@ -20,6 +20,7 @@ package org.apache.shardingsphere.proxy.frontend.mysql.authentication;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPipeline;
import io.netty.util.Attribute;
import lombok.SneakyThrows;
import org.apache.shardingsphere.authority.provider.simple.model.privilege.AllPrivilegesPermittedShardingSpherePrivileges;
@@ -50,6 +51,8 @@ import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResultBuilder;
import org.apache.shardingsphere.proxy.frontend.authentication.Authenticator;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticatorFactory;
+import org.apache.shardingsphere.proxy.frontend.mysql.ssl.MySQLSSLRequestHandler;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
import org.apache.shardingsphere.test.mock.AutoMockExtension;
import org.apache.shardingsphere.test.mock.StaticMockSettings;
import org.junit.jupiter.api.Test;
@@ -74,6 +77,7 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
@@ -82,7 +86,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(AutoMockExtension.class)
-@StaticMockSettings(ProxyContext.class)
+@StaticMockSettings({ProxyContext.class, ProxySSLContext.class})
@MockitoSettings(strictness = Strictness.LENIENT)
class MySQLAuthenticationEngineTest {
@@ -91,12 +95,22 @@ class MySQLAuthenticationEngineTest {
private final byte[] authResponse = {-27, 89, -20, -27, 65, -120, -64, -101, 86, -100, -108, -100, 6, -125, -37, 117, 14, -43, 95, -113};
@Test
- void assertHandshake() {
+ void assertHandshakeWithSSLNotEnabled() { // isSSLEnabled() is not stubbed in this test; presumably the statically mocked ProxySSLContext then reports false - confirm against AutoMockExtension defaults
ChannelHandlerContext context = mockChannelHandlerContext();
assertTrue(authenticationEngine.handshake(context) > 0);
verify(context).writeAndFlush(any(MySQLHandshakePacket.class));
}
+ @Test
+ void assertHandshakeWithSSLEnabled() {
+ when(ProxySSLContext.getInstance().isSSLEnabled()).thenReturn(true); // force the SSL-enabled branch of handshake()
+ ChannelHandlerContext context = mockChannelHandlerContext();
+ when(context.pipeline()).thenReturn(mock(ChannelPipeline.class)); // a pipeline mock is needed so that the addFirst call can be verified below
+ assertTrue(authenticationEngine.handshake(context) > 0);
+ verify(context.pipeline()).addFirst(eq(MySQLSSLRequestHandler.class.getSimpleName()), any(MySQLSSLRequestHandler.class)); // the SSL request handler must be added under its simple class name
+ verify(context).writeAndFlush(any(MySQLHandshakePacket.class));
+ }
+
@Test
void assertBadHandshakeReceived() {
AuthorityRule rule = mock(AuthorityRule.class);
diff --git a/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/ssl/MySQLSSLRequestHandlerTest.java b/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/ssl/MySQLSSLRequestHandlerTest.java
new file mode 100644
index 00000000000..467cd7a9eb8
--- /dev/null
+++ b/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/ssl/MySQLSSLRequestHandlerTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.proxy.frontend.mysql.ssl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufUtil;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.util.internal.StringUtil;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
+import org.apache.shardingsphere.test.mock.AutoMockExtension;
+import org.apache.shardingsphere.test.mock.StaticMockSettings;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.MockedConstruction;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.Mockito.mockConstruction;
+import static org.mockito.Mockito.verify;
+
+@ExtendWith(AutoMockExtension.class)
+@StaticMockSettings(ProxySSLContext.class)
+class MySQLSSLRequestHandlerTest {
+
+ private static final byte[] MYSQL_SSL_REQUEST = StringUtil.decodeHexDump("2000000185aeff1900000001080000000000000000000000000000000000000000000000");
+
+ private static final byte[] FAKE_TLS_HANDSHAKE = StringUtil.decodeHexDump("1603010000");
+
+ private static final byte[] MYSQL_NON_SSL_REQUEST = StringUtil.decodeHexDump("2000000185a6ff1900000001080000000000000000000000000000000000000000000000");
+
+ @Test
+ void assertReceiveSSLRequest() throws Exception {
+ EmbeddedChannel channel = new EmbeddedChannel();
+ channel.pipeline().addFirst(MySQLSSLRequestHandler.class.getSimpleName(), new MySQLSSLRequestHandler());
+ try (MockedConstruction<SslHandler> mockedConstruction = mockConstruction(SslHandler.class)) {
+ channel.writeInbound(Unpooled.wrappedBuffer(MYSQL_SSL_REQUEST), Unpooled.wrappedBuffer(FAKE_TLS_HANDSHAKE));
+ verify(mockedConstruction.constructed().get(0)).channelRead(any(ChannelHandlerContext.class), argThat(this::assertTLSHandshakeByteBuf));
+ }
+ assertNull(channel.pipeline().get(MySQLSSLRequestHandler.class));
+ }
+
+ private boolean assertTLSHandshakeByteBuf(final Object actual) {
+ assertThat(ByteBufUtil.getBytes((ByteBuf) actual), is(FAKE_TLS_HANDSHAKE));
+ return true;
+ }
+
+ @Test
+ void assertReceiveHandshakeResponse() {
+ EmbeddedChannel channel = new EmbeddedChannel();
+ channel.pipeline().addFirst(MySQLSSLRequestHandler.class.getSimpleName(), new MySQLSSLRequestHandler());
+ channel.writeInbound(Unpooled.wrappedBuffer(MYSQL_NON_SSL_REQUEST));
+ ByteBuf actual = channel.readInbound();
+ assertThat(ByteBufUtil.getBytes(actual), is(MYSQL_NON_SSL_REQUEST));
+ assertNull(channel.pipeline().get(MySQLSSLRequestHandler.class));
+ }
+}
diff --git a/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java b/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
index 84e9ca2a0ec..9c346c1a311 100644
--- a/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
+++ b/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
@@ -19,6 +19,7 @@ package org.apache.shardingsphere.proxy.frontend.opengauss.authentication;
import com.google.common.base.Strings;
import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.ssl.SslHandler;
import org.apache.shardingsphere.authority.checker.AuthorityChecker;
import org.apache.shardingsphere.authority.rule.AuthorityRule;
import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
@@ -34,7 +35,8 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.Postgre
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLParameterStatusPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLPasswordMessagePacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLRandomGenerator;
-import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLNegativePacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLUnwillingPacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLWillingPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.authentication.PostgreSQLMD5PasswordAuthenticationPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
@@ -57,6 +59,7 @@ import org.apache.shardingsphere.proxy.frontend.authentication.Authenticator;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticatorFactory;
import org.apache.shardingsphere.proxy.frontend.connection.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.opengauss.authentication.authenticator.OpenGaussAuthenticatorType;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
import java.util.Optional;
@@ -95,7 +98,13 @@ public final class OpenGaussAuthenticationEngine implements AuthenticationEngine
@Override
public AuthenticationResult authenticate(final ChannelHandlerContext context, final PacketPayload payload) {
if (SSL_REQUEST_PAYLOAD_LENGTH == payload.getByteBuf().markReaderIndex().readInt() && SSL_REQUEST_CODE == payload.getByteBuf().readInt()) {
- context.writeAndFlush(new PostgreSQLSSLNegativePacket());
+ if (ProxySSLContext.getInstance().isSSLEnabled()) {
+ SslHandler sslHandler = new SslHandler(ProxySSLContext.getInstance().newSSLEngine(context.alloc()), true);
+ context.pipeline().addFirst(SslHandler.class.getSimpleName(), sslHandler);
+ context.writeAndFlush(new PostgreSQLSSLWillingPacket());
+ } else {
+ context.writeAndFlush(new PostgreSQLSSLUnwillingPacket());
+ }
return AuthenticationResultBuilder.continued();
}
payload.getByteBuf().resetReaderIndex();
diff --git a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
index 08f6ac1f812..c84ab540c04 100644
--- a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
+++ b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
@@ -21,12 +21,15 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.buffer.UnpooledHeapByteBuf;
import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.ssl.SslHandler;
import io.netty.util.Attribute;
import lombok.SneakyThrows;
import org.apache.shardingsphere.authority.config.AuthorityRuleConfiguration;
import org.apache.shardingsphere.authority.rule.builder.AuthorityRuleBuilder;
import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLUnwillingPacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLWillingPacket;
import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import org.apache.shardingsphere.dialect.postgresql.exception.authority.EmptyUsernameException;
import org.apache.shardingsphere.dialect.postgresql.exception.protocol.ProtocolViolationException;
@@ -35,11 +38,12 @@ import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
+import org.apache.shardingsphere.metadata.persist.MetaDataPersistService;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
-import org.apache.shardingsphere.metadata.persist.MetaDataPersistService;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResult;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
import org.apache.shardingsphere.test.mock.AutoMockExtension;
import org.apache.shardingsphere.test.mock.StaticMockSettings;
import org.junit.jupiter.api.BeforeEach;
@@ -53,14 +57,17 @@ import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Properties;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(AutoMockExtension.class)
-@StaticMockSettings(ProxyContext.class)
+@StaticMockSettings({ProxyContext.class, ProxySSLContext.class})
class OpenGaussAuthenticationEngineTest {
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
@@ -78,12 +85,28 @@ class OpenGaussAuthenticationEngineTest {
}
@Test
- void assertSSLNegative() {
+ void assertSSLUnwilling() { // isSSLEnabled() is not stubbed in this test
+ ByteBuf byteBuf = createByteBuf(8, 8);
+ byteBuf.writeInt(8); // first int read by authenticate(); presumably equals SSL_REQUEST_PAYLOAD_LENGTH
+ byteBuf.writeInt(80877103); // second int read by authenticate(); presumably equals SSL_REQUEST_CODE
+ PacketPayload payload = new PostgreSQLPacketPayload(byteBuf, StandardCharsets.UTF_8);
+ ChannelHandlerContext context = mock(ChannelHandlerContext.class); // kept in a local so that the written packet can be verified
+ AuthenticationResult actual = new OpenGaussAuthenticationEngine().authenticate(context, payload);
+ verify(context).writeAndFlush(any(PostgreSQLSSLUnwillingPacket.class));
+ assertFalse(actual.isFinished());
+ }
+
+ @Test
+ void assertSSLWilling() {
ByteBuf byteBuf = createByteBuf(8, 8);
byteBuf.writeInt(8);
byteBuf.writeInt(80877103);
PacketPayload payload = new PostgreSQLPacketPayload(byteBuf, StandardCharsets.UTF_8);
- AuthenticationResult actual = new OpenGaussAuthenticationEngine().authenticate(mock(ChannelHandlerContext.class), payload);
+ ChannelHandlerContext context = mock(ChannelHandlerContext.class, RETURNS_DEEP_STUBS); // deep stubs so that context.pipeline() yields a mock that can be verified below
+ when(ProxySSLContext.getInstance().isSSLEnabled()).thenReturn(true); // force the SSL-enabled branch of authenticate()
+ AuthenticationResult actual = new OpenGaussAuthenticationEngine().authenticate(context, payload);
+ verify(context).writeAndFlush(any(PostgreSQLSSLWillingPacket.class));
+ verify(context.pipeline()).addFirst(eq(SslHandler.class.getSimpleName()), any(SslHandler.class)); // the SslHandler must be added first under its simple class name
assertFalse(actual.isFinished());
}
diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
index 35615b92600..6651058cebf 100644
--- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
+++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
@@ -19,6 +19,7 @@ package org.apache.shardingsphere.proxy.frontend.postgresql.authentication;
import com.google.common.base.Strings;
import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.ssl.SslHandler;
import org.apache.shardingsphere.authority.checker.AuthorityChecker;
import org.apache.shardingsphere.authority.rule.AuthorityRule;
import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
@@ -31,7 +32,8 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.Postgre
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLParameterStatusPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLPasswordMessagePacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLRandomGenerator;
-import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLNegativePacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLUnwillingPacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLWillingPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.authentication.PostgreSQLMD5PasswordAuthenticationPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.authentication.PostgreSQLPasswordAuthenticationPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
@@ -55,6 +57,7 @@ import org.apache.shardingsphere.proxy.frontend.authentication.Authenticator;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticatorFactory;
import org.apache.shardingsphere.proxy.frontend.connection.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.postgresql.authentication.authenticator.PostgreSQLAuthenticatorType;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
import java.util.Optional;
@@ -83,7 +86,13 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
@Override
public AuthenticationResult authenticate(final ChannelHandlerContext context, final PacketPayload payload) {
if (SSL_REQUEST_PAYLOAD_LENGTH == payload.getByteBuf().markReaderIndex().readInt() && SSL_REQUEST_CODE == payload.getByteBuf().readInt()) {
- context.writeAndFlush(new PostgreSQLSSLNegativePacket());
+ if (ProxySSLContext.getInstance().isSSLEnabled()) {
+ SslHandler sslHandler = new SslHandler(ProxySSLContext.getInstance().newSSLEngine(context.alloc()), true);
+ context.pipeline().addFirst(SslHandler.class.getSimpleName(), sslHandler);
+ context.writeAndFlush(new PostgreSQLSSLWillingPacket());
+ } else {
+ context.writeAndFlush(new PostgreSQLSSLUnwillingPacket());
+ }
return AuthenticationResultBuilder.continued();
}
payload.getByteBuf().resetReaderIndex();
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngineTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngineTest.java
index 294bb19ab71..ccfdd824200 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngineTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngineTest.java
@@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.buffer.UnpooledHeapByteBuf;
import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.ssl.SslHandler;
import io.netty.util.Attribute;
import lombok.SneakyThrows;
import org.apache.shardingsphere.authority.config.AuthorityRuleConfiguration;
@@ -28,6 +29,8 @@ import org.apache.shardingsphere.authority.rule.AuthorityRule;
import org.apache.shardingsphere.authority.rule.builder.AuthorityRuleBuilder;
import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLUnwillingPacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLWillingPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.authentication.PostgreSQLMD5PasswordAuthenticationPacket;
import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import org.apache.shardingsphere.dialect.postgresql.exception.authority.EmptyUsernameException;
@@ -44,6 +47,7 @@ import org.apache.shardingsphere.metadata.persist.MetaDataPersistService;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResult;
import org.apache.shardingsphere.proxy.frontend.postgresql.authentication.authenticator.impl.PostgreSQLMD5PasswordAuthenticator;
+import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
import org.apache.shardingsphere.test.mock.AutoMockExtension;
import org.apache.shardingsphere.test.mock.StaticMockSettings;
import org.junit.jupiter.api.BeforeEach;
@@ -63,13 +67,15 @@ import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(AutoMockExtension.class)
-@StaticMockSettings(ProxyContext.class)
+@StaticMockSettings({ProxyContext.class, ProxySSLContext.class})
class PostgreSQLAuthenticationEngineTest {
private final String username = "root";
@@ -86,12 +92,28 @@ class PostgreSQLAuthenticationEngineTest {
}
@Test
- void assertSSLNegative() {
+ void assertSSLUnwilling() {
ByteBuf byteBuf = createByteBuf(8, 8);
byteBuf.writeInt(8);
byteBuf.writeInt(80877103);
PacketPayload payload = new PostgreSQLPacketPayload(byteBuf, StandardCharsets.UTF_8);
- AuthenticationResult actual = new PostgreSQLAuthenticationEngine().authenticate(mock(ChannelHandlerContext.class), payload);
+ ChannelHandlerContext context = mock(ChannelHandlerContext.class);
+ AuthenticationResult actual = new PostgreSQLAuthenticationEngine().authenticate(context, payload);
+ verify(context).writeAndFlush(any(PostgreSQLSSLUnwillingPacket.class));
+ assertFalse(actual.isFinished());
+ }
+
+ @Test
+ void assertSSLWilling() {
+ ByteBuf byteBuf = createByteBuf(8, 8);
+ byteBuf.writeInt(8);
+ byteBuf.writeInt(80877103);
+ PacketPayload payload = new PostgreSQLPacketPayload(byteBuf, StandardCharsets.UTF_8);
+ ChannelHandlerContext context = mock(ChannelHandlerContext.class, RETURNS_DEEP_STUBS);
+ when(ProxySSLContext.getInstance().isSSLEnabled()).thenReturn(true);
+ AuthenticationResult actual = new PostgreSQLAuthenticationEngine().authenticate(context, payload);
+ verify(context).writeAndFlush(any(PostgreSQLSSLWillingPacket.class));
+ verify(context.pipeline()).addFirst(eq(SslHandler.class.getSimpleName()), any(SslHandler.class));
assertFalse(actual.isFinished());
}