You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original message and is not preserved in this plain-text rendering.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2023/05/31 11:39:19 UTC

[shardingsphere] branch master updated: Refactor PostgreSQLAggregatedCommandPacket (#25961)

This is an automated email from the ASF dual-hosted git repository.

panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 0d7bc500aad Refactor PostgreSQLAggregatedCommandPacket (#25961)
0d7bc500aad is described below

commit 0d7bc500aadf295e16845a07267710b009da23f2
Author: Liang Zhang <zh...@apache.org>
AuthorDate: Wed May 31 19:39:11 2023 +0800

    Refactor PostgreSQLAggregatedCommandPacket (#25961)
---
 .../db/protocol/codec/PacketCodecTest.java         |  1 -
 .../PostgreSQLAggregatedCommandPacket.java         | 57 +++++++++-------------
 .../command/OpenGaussCommandExecutorFactory.java   | 12 ++---
 .../OpenGaussCommandExecutorFactoryTest.java       |  4 +-
 .../command/PostgreSQLCommandExecutorFactory.java  | 12 ++---
 .../PostgreSQLCommandExecutorFactoryTest.java      |  4 +-
 6 files changed, 40 insertions(+), 50 deletions(-)

diff --git a/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java b/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
index d2e0808f3d0..3a9ce8e99a9 100644
--- a/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
+++ b/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
@@ -68,7 +68,6 @@ class PacketCodecTest {
         verify(databasePacketCodecEngine, times(0)).decode(context, byteBuf, Collections.emptyList());
     }
     
-    @SuppressWarnings("unchecked")
     @Test
     void assertEncode() {
         DatabasePacket databasePacket = mock(DatabasePacket.class);
diff --git a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
index 21a8218aec5..0cf186eb857 100644
--- a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
+++ b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
@@ -17,7 +17,6 @@
 
 package org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended;
 
-import com.google.common.base.Preconditions;
 import lombok.Getter;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLComBindPacket;
@@ -27,7 +26,6 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.Postgr
 import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
 
 import java.util.List;
-import java.util.RandomAccess;
 
 @Getter
 public final class PostgreSQLAggregatedCommandPacket extends PostgreSQLCommandPacket {
@@ -36,38 +34,38 @@ public final class PostgreSQLAggregatedCommandPacket extends PostgreSQLCommandPa
     
     private final boolean containsBatchedStatements;
     
-    private final int firstBindIndex;
+    private final int batchPacketBeginIndex;
     
-    private final int lastExecuteIndex;
+    private final int batchPacketEndIndex;
     
     public PostgreSQLAggregatedCommandPacket(final List<PostgreSQLCommandPacket> packets) {
         this.packets = packets;
-        int parseTimes = 0;
-        int firstStatementBindTimes = 0;
-        int firstStatementExecuteTimes = 0;
-        String firstStatement = null;
+        String firstStatementId = null;
         String firstPortal = null;
+        int parsePacketCount = 0;
+        int bindPacketCountForFirstStatement = 0;
+        int executePacketCountForFirstStatement = 0;
+        int batchPacketBeginIndex = -1;
+        int batchPacketEndIndex = -1;
         int index = 0;
-        int firstBindIndex = -1;
-        int lastExecuteIndex = -1;
         for (PostgreSQLCommandPacket each : packets) {
             if (each instanceof PostgreSQLComParsePacket) {
-                if (++parseTimes > 1) {
+                if (++parsePacketCount > 1) {
                     break;
                 }
-                if (null == firstStatement) {
-                    firstStatement = ((PostgreSQLComParsePacket) each).getStatementId();
-                } else if (!firstStatement.equals(((PostgreSQLComParsePacket) each).getStatementId())) {
+                if (null == firstStatementId) {
+                    firstStatementId = ((PostgreSQLComParsePacket) each).getStatementId();
+                } else if (!firstStatementId.equals(((PostgreSQLComParsePacket) each).getStatementId())) {
                     break;
                 }
             }
             if (each instanceof PostgreSQLComBindPacket) {
-                if (-1 == firstBindIndex) {
-                    firstBindIndex = index;
+                if (-1 == batchPacketBeginIndex) {
+                    batchPacketBeginIndex = index;
                 }
-                if (null == firstStatement) {
-                    firstStatement = ((PostgreSQLComBindPacket) each).getStatementId();
-                } else if (!firstStatement.equals(((PostgreSQLComBindPacket) each).getStatementId())) {
+                if (null == firstStatementId) {
+                    firstStatementId = ((PostgreSQLComBindPacket) each).getStatementId();
+                } else if (!firstStatementId.equals(((PostgreSQLComBindPacket) each).getStatementId())) {
                     break;
                 }
                 if (null == firstPortal) {
@@ -75,31 +73,24 @@ public final class PostgreSQLAggregatedCommandPacket extends PostgreSQLCommandPa
                 } else if (!firstPortal.equals(((PostgreSQLComBindPacket) each).getPortal())) {
                     break;
                 }
-                firstStatementBindTimes++;
+                bindPacketCountForFirstStatement++;
             }
             if (each instanceof PostgreSQLComExecutePacket) {
-                if (index > lastExecuteIndex) {
-                    lastExecuteIndex = index;
+                if (index > batchPacketEndIndex) {
+                    batchPacketEndIndex = index;
                 }
                 if (null == firstPortal) {
                     firstPortal = ((PostgreSQLComExecutePacket) each).getPortal();
                 } else if (!firstPortal.equals(((PostgreSQLComExecutePacket) each).getPortal())) {
                     break;
                 }
-                firstStatementExecuteTimes++;
+                executePacketCountForFirstStatement++;
             }
             index++;
         }
-        this.firstBindIndex = firstBindIndex;
-        this.lastExecuteIndex = lastExecuteIndex;
-        containsBatchedStatements = firstStatementBindTimes == firstStatementExecuteTimes && firstStatementBindTimes >= 3;
-        if (containsBatchedStatements) {
-            ensureRandomAccessible(packets);
-        }
-    }
-    
-    private void ensureRandomAccessible(final List<PostgreSQLCommandPacket> packets) {
-        Preconditions.checkArgument(packets instanceof RandomAccess, "Packets must be RandomAccess.");
+        this.batchPacketBeginIndex = batchPacketBeginIndex;
+        this.batchPacketEndIndex = batchPacketEndIndex;
+        containsBatchedStatements = bindPacketCountForFirstStatement == executePacketCountForFirstStatement && bindPacketCountForFirstStatement >= 3;
     }
     
     @Override
diff --git a/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java b/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
index 2a333e4f580..553e4011df6 100644
--- a/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
+++ b/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
@@ -95,15 +95,15 @@ public final class OpenGaussCommandExecutorFactory {
     private static List<CommandExecutor> getExecutorsOfAggregatedBatchedStatements(final PostgreSQLAggregatedCommandPacket aggregatedCommandPacket,
                                                                                    final ConnectionSession connectionSession, final PortalContext portalContext) throws SQLException {
         List<PostgreSQLCommandPacket> packets = aggregatedCommandPacket.getPackets();
-        int firstBindIndex = aggregatedCommandPacket.getFirstBindIndex();
-        int lastExecuteIndex = aggregatedCommandPacket.getLastExecuteIndex();
-        List<CommandExecutor> result = new ArrayList<>(firstBindIndex + packets.size() - lastExecuteIndex);
-        for (int i = 0; i < firstBindIndex; i++) {
+        int batchPacketBeginIndex = aggregatedCommandPacket.getBatchPacketBeginIndex();
+        int batchPacketEndIndex = aggregatedCommandPacket.getBatchPacketEndIndex();
+        List<CommandExecutor> result = new ArrayList<>(batchPacketBeginIndex + packets.size() - batchPacketEndIndex);
+        for (int i = 0; i < batchPacketBeginIndex; i++) {
             PostgreSQLCommandPacket each = packets.get(i);
             result.add(getCommandExecutor((CommandPacketType) each.getIdentifier(), each, connectionSession, portalContext));
         }
-        result.add(new PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession, packets.subList(firstBindIndex, lastExecuteIndex + 1)));
-        for (int i = lastExecuteIndex + 1; i < packets.size(); i++) {
+        result.add(new PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession, packets.subList(batchPacketBeginIndex, batchPacketEndIndex + 1)));
+        for (int i = batchPacketEndIndex + 1; i < packets.size(); i++) {
             PostgreSQLCommandPacket each = packets.get(i);
             result.add(getCommandExecutor((CommandPacketType) each.getIdentifier(), each, connectionSession, portalContext));
         }
diff --git a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
index ccb77e4eaf3..f0d01c2519c 100644
--- a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
+++ b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
@@ -135,8 +135,8 @@ class OpenGaussCommandExecutorFactoryTest {
         when(packet.isContainsBatchedStatements()).thenReturn(true);
         when(packet.getPackets()).thenReturn(
                 Arrays.asList(parsePacket, bindPacket, describePacket, executePacket, bindPacket, describePacket, executePacket, closePacket, syncPacket, terminationPacket));
-        when(packet.getFirstBindIndex()).thenReturn(1);
-        when(packet.getLastExecuteIndex()).thenReturn(6);
+        when(packet.getBatchPacketBeginIndex()).thenReturn(1);
+        when(packet.getBatchPacketEndIndex()).thenReturn(6);
         CommandExecutor actual = OpenGaussCommandExecutorFactory.newInstance(null, packet, connectionSession, portalContext);
         assertThat(actual, instanceOf(PostgreSQLAggregatedCommandExecutor.class));
         Iterator<CommandExecutor> actualPacketsIterator = getExecutorsFromAggregatedCommandExecutor((PostgreSQLAggregatedCommandExecutor) actual).iterator();
diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
index 9da80efb4b8..20ac3c66998 100644
--- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
+++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
@@ -90,15 +90,15 @@ public final class PostgreSQLCommandExecutorFactory {
     private static List<CommandExecutor> getExecutorsOfAggregatedBatchedStatements(final PostgreSQLAggregatedCommandPacket aggregatedCommandPacket,
                                                                                    final ConnectionSession connectionSession, final PortalContext portalContext) throws SQLException {
         List<PostgreSQLCommandPacket> packets = aggregatedCommandPacket.getPackets();
-        int firstBindIndex = aggregatedCommandPacket.getFirstBindIndex();
-        int lastExecuteIndex = aggregatedCommandPacket.getLastExecuteIndex();
-        List<CommandExecutor> result = new ArrayList<>(firstBindIndex + packets.size() - lastExecuteIndex);
-        for (int i = 0; i < firstBindIndex; i++) {
+        int batchPacketBeginIndex = aggregatedCommandPacket.getBatchPacketBeginIndex();
+        int batchPacketEndIndex = aggregatedCommandPacket.getBatchPacketEndIndex();
+        List<CommandExecutor> result = new ArrayList<>(batchPacketBeginIndex + packets.size() - batchPacketEndIndex);
+        for (int i = 0; i < batchPacketBeginIndex; i++) {
             PostgreSQLCommandPacket each = packets.get(i);
             result.add(getCommandExecutor((PostgreSQLCommandPacketType) each.getIdentifier(), each, connectionSession, portalContext));
         }
-        result.add(new PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession, packets.subList(firstBindIndex, lastExecuteIndex + 1)));
-        for (int i = lastExecuteIndex + 1; i < packets.size(); i++) {
+        result.add(new PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession, packets.subList(batchPacketBeginIndex, batchPacketEndIndex + 1)));
+        for (int i = batchPacketEndIndex + 1; i < packets.size(); i++) {
             PostgreSQLCommandPacket each = packets.get(i);
             result.add(getCommandExecutor((PostgreSQLCommandPacketType) each.getIdentifier(), each, connectionSession, portalContext));
         }
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
index 712336bd7fa..b54f3e83214 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
@@ -141,8 +141,8 @@ class PostgreSQLCommandExecutorFactoryTest {
         PostgreSQLAggregatedCommandPacket packet = mock(PostgreSQLAggregatedCommandPacket.class);
         when(packet.isContainsBatchedStatements()).thenReturn(true);
         when(packet.getPackets()).thenReturn(Arrays.asList(parsePacket, bindPacket, describePacket, executePacket, bindPacket, describePacket, executePacket, syncPacket));
-        when(packet.getFirstBindIndex()).thenReturn(1);
-        when(packet.getLastExecuteIndex()).thenReturn(6);
+        when(packet.getBatchPacketBeginIndex()).thenReturn(1);
+        when(packet.getBatchPacketEndIndex()).thenReturn(6);
         CommandExecutor actual = PostgreSQLCommandExecutorFactory.newInstance(null, packet, connectionSession, portalContext);
         assertThat(actual, instanceOf(PostgreSQLAggregatedCommandExecutor.class));
         Iterator<CommandExecutor> actualPacketsIterator = getExecutorsFromAggregatedCommandExecutor((PostgreSQLAggregatedCommandExecutor) actual).iterator();