You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2023/05/31 11:39:19 UTC
[shardingsphere] branch master updated: Refactor PostgreSQLAggregatedCommandPacket (#25961)
This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 0d7bc500aad Refactor PostgreSQLAggregatedCommandPacket (#25961)
0d7bc500aad is described below
commit 0d7bc500aadf295e16845a07267710b009da23f2
Author: Liang Zhang <zh...@apache.org>
AuthorDate: Wed May 31 19:39:11 2023 +0800
Refactor PostgreSQLAggregatedCommandPacket (#25961)
---
.../db/protocol/codec/PacketCodecTest.java | 1 -
.../PostgreSQLAggregatedCommandPacket.java | 57 +++++++++-------------
.../command/OpenGaussCommandExecutorFactory.java | 12 ++---
.../OpenGaussCommandExecutorFactoryTest.java | 4 +-
.../command/PostgreSQLCommandExecutorFactory.java | 12 ++---
.../PostgreSQLCommandExecutorFactoryTest.java | 4 +-
6 files changed, 40 insertions(+), 50 deletions(-)
diff --git a/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java b/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
index d2e0808f3d0..3a9ce8e99a9 100644
--- a/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
+++ b/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
@@ -68,7 +68,6 @@ class PacketCodecTest {
verify(databasePacketCodecEngine, times(0)).decode(context, byteBuf, Collections.emptyList());
}
- @SuppressWarnings("unchecked")
@Test
void assertEncode() {
DatabasePacket databasePacket = mock(DatabasePacket.class);
diff --git a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
index 21a8218aec5..0cf186eb857 100644
--- a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
+++ b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
@@ -17,7 +17,6 @@
package org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended;
-import com.google.common.base.Preconditions;
import lombok.Getter;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLComBindPacket;
@@ -27,7 +26,6 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.Postgr
import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import java.util.List;
-import java.util.RandomAccess;
@Getter
public final class PostgreSQLAggregatedCommandPacket extends PostgreSQLCommandPacket {
@@ -36,38 +34,38 @@ public final class PostgreSQLAggregatedCommandPacket extends PostgreSQLCommandPa
private final boolean containsBatchedStatements;
- private final int firstBindIndex;
+ private final int batchPacketBeginIndex;
- private final int lastExecuteIndex;
+ private final int batchPacketEndIndex;
public PostgreSQLAggregatedCommandPacket(final List<PostgreSQLCommandPacket> packets) {
this.packets = packets;
- int parseTimes = 0;
- int firstStatementBindTimes = 0;
- int firstStatementExecuteTimes = 0;
- String firstStatement = null;
+ String firstStatementId = null;
String firstPortal = null;
+ int parsePacketCount = 0;
+ int bindPacketCountForFirstStatement = 0;
+ int executePacketCountForFirstStatement = 0;
+ int batchPacketBeginIndex = -1;
+ int batchPacketEndIndex = -1;
int index = 0;
- int firstBindIndex = -1;
- int lastExecuteIndex = -1;
for (PostgreSQLCommandPacket each : packets) {
if (each instanceof PostgreSQLComParsePacket) {
- if (++parseTimes > 1) {
+ if (++parsePacketCount > 1) {
break;
}
- if (null == firstStatement) {
- firstStatement = ((PostgreSQLComParsePacket) each).getStatementId();
- } else if (!firstStatement.equals(((PostgreSQLComParsePacket) each).getStatementId())) {
+ if (null == firstStatementId) {
+ firstStatementId = ((PostgreSQLComParsePacket) each).getStatementId();
+ } else if (!firstStatementId.equals(((PostgreSQLComParsePacket) each).getStatementId())) {
break;
}
}
if (each instanceof PostgreSQLComBindPacket) {
- if (-1 == firstBindIndex) {
- firstBindIndex = index;
+ if (-1 == batchPacketBeginIndex) {
+ batchPacketBeginIndex = index;
}
- if (null == firstStatement) {
- firstStatement = ((PostgreSQLComBindPacket) each).getStatementId();
- } else if (!firstStatement.equals(((PostgreSQLComBindPacket) each).getStatementId())) {
+ if (null == firstStatementId) {
+ firstStatementId = ((PostgreSQLComBindPacket) each).getStatementId();
+ } else if (!firstStatementId.equals(((PostgreSQLComBindPacket) each).getStatementId())) {
break;
}
if (null == firstPortal) {
@@ -75,31 +73,24 @@ public final class PostgreSQLAggregatedCommandPacket extends PostgreSQLCommandPa
} else if (!firstPortal.equals(((PostgreSQLComBindPacket) each).getPortal())) {
break;
}
- firstStatementBindTimes++;
+ bindPacketCountForFirstStatement++;
}
if (each instanceof PostgreSQLComExecutePacket) {
- if (index > lastExecuteIndex) {
- lastExecuteIndex = index;
+ if (index > batchPacketEndIndex) {
+ batchPacketEndIndex = index;
}
if (null == firstPortal) {
firstPortal = ((PostgreSQLComExecutePacket) each).getPortal();
} else if (!firstPortal.equals(((PostgreSQLComExecutePacket) each).getPortal())) {
break;
}
- firstStatementExecuteTimes++;
+ executePacketCountForFirstStatement++;
}
index++;
}
- this.firstBindIndex = firstBindIndex;
- this.lastExecuteIndex = lastExecuteIndex;
- containsBatchedStatements = firstStatementBindTimes == firstStatementExecuteTimes && firstStatementBindTimes >= 3;
- if (containsBatchedStatements) {
- ensureRandomAccessible(packets);
- }
- }
-
- private void ensureRandomAccessible(final List<PostgreSQLCommandPacket> packets) {
- Preconditions.checkArgument(packets instanceof RandomAccess, "Packets must be RandomAccess.");
+ this.batchPacketBeginIndex = batchPacketBeginIndex;
+ this.batchPacketEndIndex = batchPacketEndIndex;
+ containsBatchedStatements = bindPacketCountForFirstStatement == executePacketCountForFirstStatement && bindPacketCountForFirstStatement >= 3;
}
@Override
diff --git a/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java b/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
index 2a333e4f580..553e4011df6 100644
--- a/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
+++ b/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
@@ -95,15 +95,15 @@ public final class OpenGaussCommandExecutorFactory {
private static List<CommandExecutor> getExecutorsOfAggregatedBatchedStatements(final PostgreSQLAggregatedCommandPacket aggregatedCommandPacket,
final ConnectionSession connectionSession, final PortalContext portalContext) throws SQLException {
List<PostgreSQLCommandPacket> packets = aggregatedCommandPacket.getPackets();
- int firstBindIndex = aggregatedCommandPacket.getFirstBindIndex();
- int lastExecuteIndex = aggregatedCommandPacket.getLastExecuteIndex();
- List<CommandExecutor> result = new ArrayList<>(firstBindIndex + packets.size() - lastExecuteIndex);
- for (int i = 0; i < firstBindIndex; i++) {
+ int batchPacketBeginIndex = aggregatedCommandPacket.getBatchPacketBeginIndex();
+ int batchPacketEndIndex = aggregatedCommandPacket.getBatchPacketEndIndex();
+ List<CommandExecutor> result = new ArrayList<>(batchPacketBeginIndex + packets.size() - batchPacketEndIndex);
+ for (int i = 0; i < batchPacketBeginIndex; i++) {
PostgreSQLCommandPacket each = packets.get(i);
result.add(getCommandExecutor((CommandPacketType) each.getIdentifier(), each, connectionSession, portalContext));
}
- result.add(new PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession, packets.subList(firstBindIndex, lastExecuteIndex + 1)));
- for (int i = lastExecuteIndex + 1; i < packets.size(); i++) {
+ result.add(new PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession, packets.subList(batchPacketBeginIndex, batchPacketEndIndex + 1)));
+ for (int i = batchPacketEndIndex + 1; i < packets.size(); i++) {
PostgreSQLCommandPacket each = packets.get(i);
result.add(getCommandExecutor((CommandPacketType) each.getIdentifier(), each, connectionSession, portalContext));
}
diff --git a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
index ccb77e4eaf3..f0d01c2519c 100644
--- a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
+++ b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
@@ -135,8 +135,8 @@ class OpenGaussCommandExecutorFactoryTest {
when(packet.isContainsBatchedStatements()).thenReturn(true);
when(packet.getPackets()).thenReturn(
Arrays.asList(parsePacket, bindPacket, describePacket, executePacket, bindPacket, describePacket, executePacket, closePacket, syncPacket, terminationPacket));
- when(packet.getFirstBindIndex()).thenReturn(1);
- when(packet.getLastExecuteIndex()).thenReturn(6);
+ when(packet.getBatchPacketBeginIndex()).thenReturn(1);
+ when(packet.getBatchPacketEndIndex()).thenReturn(6);
CommandExecutor actual = OpenGaussCommandExecutorFactory.newInstance(null, packet, connectionSession, portalContext);
assertThat(actual, instanceOf(PostgreSQLAggregatedCommandExecutor.class));
Iterator<CommandExecutor> actualPacketsIterator = getExecutorsFromAggregatedCommandExecutor((PostgreSQLAggregatedCommandExecutor) actual).iterator();
diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
index 9da80efb4b8..20ac3c66998 100644
--- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
+++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
@@ -90,15 +90,15 @@ public final class PostgreSQLCommandExecutorFactory {
private static List<CommandExecutor> getExecutorsOfAggregatedBatchedStatements(final PostgreSQLAggregatedCommandPacket aggregatedCommandPacket,
final ConnectionSession connectionSession, final PortalContext portalContext) throws SQLException {
List<PostgreSQLCommandPacket> packets = aggregatedCommandPacket.getPackets();
- int firstBindIndex = aggregatedCommandPacket.getFirstBindIndex();
- int lastExecuteIndex = aggregatedCommandPacket.getLastExecuteIndex();
- List<CommandExecutor> result = new ArrayList<>(firstBindIndex + packets.size() - lastExecuteIndex);
- for (int i = 0; i < firstBindIndex; i++) {
+ int batchPacketBeginIndex = aggregatedCommandPacket.getBatchPacketBeginIndex();
+ int batchPacketEndIndex = aggregatedCommandPacket.getBatchPacketEndIndex();
+ List<CommandExecutor> result = new ArrayList<>(batchPacketBeginIndex + packets.size() - batchPacketEndIndex);
+ for (int i = 0; i < batchPacketBeginIndex; i++) {
PostgreSQLCommandPacket each = packets.get(i);
result.add(getCommandExecutor((PostgreSQLCommandPacketType) each.getIdentifier(), each, connectionSession, portalContext));
}
- result.add(new PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession, packets.subList(firstBindIndex, lastExecuteIndex + 1)));
- for (int i = lastExecuteIndex + 1; i < packets.size(); i++) {
+ result.add(new PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession, packets.subList(batchPacketBeginIndex, batchPacketEndIndex + 1)));
+ for (int i = batchPacketEndIndex + 1; i < packets.size(); i++) {
PostgreSQLCommandPacket each = packets.get(i);
result.add(getCommandExecutor((PostgreSQLCommandPacketType) each.getIdentifier(), each, connectionSession, portalContext));
}
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
index 712336bd7fa..b54f3e83214 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
@@ -141,8 +141,8 @@ class PostgreSQLCommandExecutorFactoryTest {
PostgreSQLAggregatedCommandPacket packet = mock(PostgreSQLAggregatedCommandPacket.class);
when(packet.isContainsBatchedStatements()).thenReturn(true);
when(packet.getPackets()).thenReturn(Arrays.asList(parsePacket, bindPacket, describePacket, executePacket, bindPacket, describePacket, executePacket, syncPacket));
- when(packet.getFirstBindIndex()).thenReturn(1);
- when(packet.getLastExecuteIndex()).thenReturn(6);
+ when(packet.getBatchPacketBeginIndex()).thenReturn(1);
+ when(packet.getBatchPacketEndIndex()).thenReturn(6);
CommandExecutor actual = PostgreSQLCommandExecutorFactory.newInstance(null, packet, connectionSession, portalContext);
assertThat(actual, instanceOf(PostgreSQLAggregatedCommandExecutor.class));
Iterator<CommandExecutor> actualPacketsIterator = getExecutorsFromAggregatedCommandExecutor((PostgreSQLAggregatedCommandExecutor) actual).iterator();