You are viewing a plain-text version of this content. The canonical (hyperlinked) version is available in the mailing-list archive.
Posted to notifications@shardingsphere.apache.org by wu...@apache.org on 2021/04/14 16:53:44 UTC

[shardingsphere] branch master updated: Refactor PostgreSQLComStartupPacket (#10091)

This is an automated email from the ASF dual-hosted git repository.

wuweijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new d312e29  Refactor PostgreSQLComStartupPacket (#10091)
d312e29 is described below

commit d312e297b1233740f45455a28a546948a55ce590
Author: Liang Zhang <te...@163.com>
AuthorDate: Thu Apr 15 00:53:12 2021 +0800

    Refactor PostgreSQLComStartupPacket (#10091)
    
    * Refactor PostgreSQLComStartupPacket
    
    * Refactor PostgreSQLComStartupPacket
    
    * Refactor PostgreSQLAuthenticationEngine
    
    * Refactor PostgreSQLAuthenticationEngine
---
 .../handshake/PostgreSQLComStartupPacket.java      | 26 +++++++--
 .../postgresql/packet/ByteBufTestUtils.java        |  6 +--
 .../generic/PostgreSQLComStartupPacketTest.java    | 61 +++++++++++++++-------
 .../PostgreSQLAuthenticationEngine.java            | 20 +++----
 4 files changed, 74 insertions(+), 39 deletions(-)

diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
index 7a86107..7fcedd4 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
@@ -17,7 +17,6 @@
 
 package org.apache.shardingsphere.db.protocol.postgresql.packet.handshake;
 
-import lombok.Getter;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
 import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
 
@@ -27,10 +26,13 @@ import java.util.Map;
 /**
  * Startup packet for PostgreSQL.
  */
-@Getter
 public final class PostgreSQLComStartupPacket implements PostgreSQLPacket {
     
-    private final Map<String, String> parametersMap = new HashMap<>(16, 1);
+    private static final String DATABASE_NAME_KEY = "database";
+    
+    private static final String USER_NAME_KEY = "user";
+    
+    private final Map<String, String> parametersMap = new HashMap<>();
     
     public PostgreSQLComStartupPacket(final PostgreSQLPacketPayload payload) {
         payload.skipReserved(8);
@@ -39,6 +41,24 @@ public final class PostgreSQLComStartupPacket implements PostgreSQLPacket {
         }
     }
     
+    /**
+     * Get database.
+     * 
+     * @return database
+     */
+    public String getDatabase() {
+        return parametersMap.get(DATABASE_NAME_KEY);
+    }
+    
+    /**
+     * Get user.
+     * 
+     * @return user
+     */
+    public String getUser() {
+        return parametersMap.get(USER_NAME_KEY);
+    }
+    
     @Override
     public void write(final PostgreSQLPacketPayload payload) {
     }
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/ByteBufTestUtils.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/ByteBufTestUtils.java
index a84bbce..f3249c3 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/ByteBufTestUtils.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/ByteBufTestUtils.java
@@ -30,7 +30,7 @@ public final class ByteBufTestUtils {
      * Creates a new buffer with a newly allocated byte array, fixed capacity.
      *
      * @param capacity the fixed capacity of the underlying byte array
-     * @return ByteBuf
+     * @return byte buffer
      */
     public static ByteBuf createByteBuf(final int capacity) {
         return createByteBuf(capacity, capacity);
@@ -40,8 +40,8 @@ public final class ByteBufTestUtils {
      * Creates a new buffer with a newly allocated byte array.
      *
      * @param initialCapacity the initial capacity of the underlying byte array
-     * @param maxCapacity     the max capacity of the underlying byte array
-     * @return ByteBuf
+     * @param maxCapacity the max capacity of the underlying byte array
+     * @return byte buffer
      */
     public static ByteBuf createByteBuf(final int initialCapacity, final int maxCapacity) {
         UnpooledByteBufAllocator byteBufAllocator = UnpooledByteBufAllocator.DEFAULT;
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLComStartupPacketTest.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLComStartupPacketTest.java
index 44130f3..e2aa129 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLComStartupPacketTest.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLComStartupPacketTest.java
@@ -25,34 +25,55 @@ import org.junit.Test;
 
 import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.Map.Entry;
 
 import static org.hamcrest.CoreMatchers.is;
 import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.mock;
 
 public final class PostgreSQLComStartupPacketTest {
     
     @Test
-    public void assertReadWrite() {
-        Map<String, String> expectedParametersMap = new LinkedHashMap<>();
-        expectedParametersMap.put("user", "postgres");
-        expectedParametersMap.put("database", "postgres");
-        int expectedLength = 4 + 4;
-        for (Map.Entry<String, String> each : expectedParametersMap.entrySet()) {
-            expectedLength += each.getKey().length() + 1;
-            expectedLength += each.getValue().length() + 1;
+    public void assertNewPostgreSQLComStartupPacket() {
+        Map<String, String> parametersMap = createParametersMap();
+        int packetMessageLength = getPacketMessageLength(parametersMap);
+        ByteBuf byteBuf = ByteBufTestUtils.createByteBuf(packetMessageLength);
+        PostgreSQLPacketPayload payload = createPayload(parametersMap, packetMessageLength, byteBuf);
+        PostgreSQLComStartupPacket actual = new PostgreSQLComStartupPacket(payload);
+        assertThat(actual.getDatabase(), is("test_db"));
+        assertThat(actual.getUser(), is("postgres"));
+        assertThat(byteBuf.writerIndex(), is(packetMessageLength));
+    }
+    
+    private Map<String, String> createParametersMap() {
+        Map<String, String> result = new LinkedHashMap<>(2, 1);
+        result.put("database", "test_db");
+        result.put("user", "postgres");
+        return result;
+    }
+    
+    private int getPacketMessageLength(final Map<String, String> parametersMap) {
+        int result = 4 + 4;
+        for (Entry<String, String> entry : parametersMap.entrySet()) {
+            result += entry.getKey().length() + 1;
+            result += entry.getValue().length() + 1;
         }
-        ByteBuf byteBuf = ByteBufTestUtils.createByteBuf(expectedLength);
-        PostgreSQLPacketPayload payload = new PostgreSQLPacketPayload(byteBuf);
-        payload.writeInt4(expectedLength);
-        payload.writeInt4(196608);
-        for (Map.Entry<String, String> each : expectedParametersMap.entrySet()) {
-            payload.writeStringNul(each.getKey());
-            payload.writeStringNul(each.getValue());
+        return result;
+    }
+    
+    private PostgreSQLPacketPayload createPayload(final Map<String, String> actualParametersMap, final int actualMessageLength, final ByteBuf byteBuf) {
+        PostgreSQLPacketPayload result = new PostgreSQLPacketPayload(byteBuf);
+        result.writeInt4(actualMessageLength);
+        result.writeInt4(196608);
+        for (Entry<String, String> entry : actualParametersMap.entrySet()) {
+            result.writeStringNul(entry.getKey());
+            result.writeStringNul(entry.getValue());
         }
-        PostgreSQLComStartupPacket packet = new PostgreSQLComStartupPacket(payload);
-        Map<String, String> actualParametersMap = packet.getParametersMap();
-        assertThat(actualParametersMap, is(expectedParametersMap));
-        packet.write(payload);
-        assertThat(byteBuf.writerIndex(), is(expectedLength));
+        return result;
+    }
+    
+    @Test
+    public void assertWrite() {
+        new PostgreSQLComStartupPacket(mock(PostgreSQLPacketPayload.class)).write(mock(PostgreSQLPacketPayload.class));
     }
 }
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
index 4a51655..33950dd 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
@@ -49,10 +49,6 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
     
     private static final int SSL_REQUEST_CODE = 80877103;
     
-    private static final String USER_NAME_KEYWORD = "user";
-    
-    private static final String DATABASE_NAME_KEYWORD = "database";
-    
     private final AtomicBoolean startupMessageReceived = new AtomicBoolean(false);
     
     private volatile byte[] md5Salt;
@@ -79,23 +75,21 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
     private AuthenticationResult beforeStartupMessage(final ChannelHandlerContext context, final PostgreSQLPacketPayload payload) {
         PostgreSQLComStartupPacket comStartupPacket = new PostgreSQLComStartupPacket(payload);
         startupMessageReceived.set(true);
-        String databaseName = comStartupPacket.getParametersMap().get(DATABASE_NAME_KEYWORD);
-        if (!Strings.isNullOrEmpty(databaseName) && !ProxyContext.getInstance().schemaExists(databaseName)) {
-            PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME, String.format("database \"%s\" does not exist", databaseName));
-            context.writeAndFlush(responsePacket);
+        String database = comStartupPacket.getDatabase();
+        if (!Strings.isNullOrEmpty(database) && !ProxyContext.getInstance().schemaExists(database)) {
+            context.writeAndFlush(createErrorPacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME, String.format("database \"%s\" does not exist", database)));
             context.close();
             return AuthenticationResultBuilder.continued();
         }
-        String username = comStartupPacket.getParametersMap().get(USER_NAME_KEYWORD);
-        if (null == username || username.isEmpty()) {
-            PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION, "user not set in StartupMessage");
-            context.writeAndFlush(responsePacket);
+        String user = comStartupPacket.getUser();
+        if (Strings.isNullOrEmpty(user)) {
+            context.writeAndFlush(createErrorPacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION, "user not set in StartupMessage"));
             context.close();
             return AuthenticationResultBuilder.continued();
         }
         md5Salt = PostgreSQLRandomGenerator.getInstance().generateRandomBytes(4);
         context.writeAndFlush(new PostgreSQLAuthenticationMD5PasswordPacket(md5Salt));
-        currentAuthResult = AuthenticationResultBuilder.continued(username, "", databaseName);
+        currentAuthResult = AuthenticationResultBuilder.continued(user, "", database);
         return currentAuthResult;
     }