You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by wu...@apache.org on 2023/06/03 11:22:24 UTC

[shardingsphere] branch master updated: re-order pg parameters in jdbc style (#25988)

This is an automated email from the ASF dual-hosted git repository.

wuweijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f989592509 re-order pg parameters in jdbc style (#25988)
4f989592509 is described below

commit 4f9895925092b18a654aa6efa5976f689af9a0ae
Author: 亥时 <be...@gmail.com>
AuthorDate: Sat Jun 3 19:22:12 2023 +0800

    re-order pg parameters in jdbc style (#25988)
---
 .../PostgreSQLServerPreparedStatement.java         | 36 ++++++++++++++++++++--
 .../extended/bind/PostgreSQLComBindExecutor.java   |  4 ++-
 .../extended/parse/PostgreSQLComParseExecutor.java | 11 ++++---
 .../bind/PostgreSQLComBindExecutorTest.java        | 33 +++++++++++++++++++-
 .../parse/PostgreSQLComParseExecutorTest.java      | 18 +++++++++++
 5 files changed, 94 insertions(+), 8 deletions(-)

diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLServerPreparedStatement.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLServerPreparedStatement.java
index 68a28ef8160..38396559ee6 100644
--- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLServerPreparedStatement.java
+++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLServerPreparedStatement.java
@@ -19,7 +19,6 @@ package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extend
 
 import lombok.AccessLevel;
 import lombok.Getter;
-import lombok.RequiredArgsConstructor;
 import lombok.Setter;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLParameterDescriptionPacket;
@@ -27,13 +26,14 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.ext
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import org.apache.shardingsphere.proxy.backend.session.ServerPreparedStatement;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
 
 /**
  * Prepared statement for PostgreSQL.
  */
-@RequiredArgsConstructor
 @Getter
 @Setter
 public final class PostgreSQLServerPreparedStatement implements ServerPreparedStatement {
@@ -44,9 +44,41 @@ public final class PostgreSQLServerPreparedStatement implements ServerPreparedSt
     
     private final List<PostgreSQLColumnType> parameterTypes;
     
+    private final List<Integer> actualParameterMarkerIndexes;
+    
     @Getter(AccessLevel.NONE)
     private PostgreSQLPacket rowDescription;
     
+    /**
+     * Create a prepared statement without recorded parameter marker indexes, so bound parameters keep their received order.
+     *
+     * @param sql SQL
+     * @param sqlStatementContext SQL statement context
+     * @param parameterTypes parameter types
+     */
+    public PostgreSQLServerPreparedStatement(final String sql, final SQLStatementContext sqlStatementContext, final List<PostgreSQLColumnType> parameterTypes) {
+        this(sql, sqlStatementContext, parameterTypes, Collections.emptyList());
+    }
+    
+    /**
+     * Create a prepared statement whose bound parameters are re-ordered by {@link #adjustParametersOrder(List)}.
+     *
+     * @param sql SQL
+     * @param sqlStatementContext SQL statement context
+     * @param parameterTypes parameter types
+     * @param actualParameterMarkerIndexes for each JDBC style {@code ?} marker in order, the 0-based index of the PostgreSQL parameter it binds,
+     *                                     e.g. {@code [1, 0]} for {@code "... $2 ... $1"}
+     */
+    public PostgreSQLServerPreparedStatement(final String sql,
+                                             final SQLStatementContext sqlStatementContext,
+                                             final List<PostgreSQLColumnType> parameterTypes,
+                                             final List<Integer> actualParameterMarkerIndexes) {
+        this.sql = sql;
+        this.sqlStatementContext = sqlStatementContext;
+        this.parameterTypes = parameterTypes;
+        this.actualParameterMarkerIndexes = actualParameterMarkerIndexes;
+    }
+    
     /**
      * Describe parameters of the prepared statement.
      *
@@ -64,4 +96,22 @@ public final class PostgreSQLServerPreparedStatement implements ServerPreparedSt
     public Optional<PostgreSQLPacket> describeRows() {
         return Optional.ofNullable(rowDescription);
     }
+    
+    /**
+     * Adjust parameters from PostgreSQL marker index order to JDBC style marker order.
+     *
+     * @param parameters parameters in PostgreSQL marker index order ($1, $2, ...)
+     * @return parameters in JDBC style marker order, or the given parameters as-is if there is nothing to re-order
+     */
+    public List<Object> adjustParametersOrder(final List<Object> parameters) {
+        // no recorded marker indexes (e.g. 3-argument constructor): keep the received order rather than returning an empty list
+        if (null == parameters || parameters.isEmpty() || actualParameterMarkerIndexes.isEmpty()) {
+            return parameters;
+        }
+        List<Object> result = new ArrayList<>(actualParameterMarkerIndexes.size());
+        for (int each : actualParameterMarkerIndexes) {
+            result.add(parameters.get(each));
+        }
+        return result;
+    }
 }
diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutor.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutor.java
index e13b3b44b78..bb02e9ea49e 100644
--- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutor.java
+++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutor.java
@@ -31,6 +31,7 @@ import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extende
 import java.sql.SQLException;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 
 /**
  * Command bind executor for PostgreSQL.
@@ -48,7 +49,8 @@ public final class PostgreSQLComBindExecutor implements CommandExecutor {
     public Collection<DatabasePacket> execute() throws SQLException {
         PostgreSQLServerPreparedStatement preparedStatement = connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(packet.getStatementId());
         ProxyDatabaseConnectionManager databaseConnectionManager = connectionSession.getDatabaseConnectionManager();
-        Portal portal = new Portal(packet.getPortal(), preparedStatement, packet.readParameters(preparedStatement.getParameterTypes()), packet.readResultFormats(), databaseConnectionManager);
+        List<Object> parameters = preparedStatement.adjustParametersOrder(packet.readParameters(preparedStatement.getParameterTypes())); // PostgreSQL $n order -> JDBC style marker order
+        Portal portal = new Portal(packet.getPortal(), preparedStatement, parameters, packet.readResultFormats(), databaseConnectionManager);
         portalContext.add(portal);
         portal.bind();
         return Collections.singleton(PostgreSQLBindCompletePacket.getInstance());
diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
index d132ce68ef4..cb6d5b1f7aa 100644
--- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
+++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
@@ -45,6 +45,7 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * PostgreSQL command parse executor.
@@ -61,15 +62,20 @@ public final class PostgreSQLComParseExecutor implements CommandExecutor {
         SQLParserEngine sqlParserEngine = createShardingSphereSQLParserEngine(connectionSession.getDatabaseName());
         String sql = packet.getSQL();
         SQLStatement sqlStatement = sqlParserEngine.parse(sql, true);
+        List<Integer> actualParameterMarkerIndexes = new ArrayList<>();
         if (sqlStatement.getParameterCount() > 0) {
-            sql = convertSQLToJDBCStyle(sqlStatement, sql);
+            List<ParameterMarkerSegment> parameterMarkerSegments = new ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
+            // JDBC style markers are bound by their position in the SQL, so record the PostgreSQL indexes in that same positional order
+            parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
+            actualParameterMarkerIndexes.addAll(parameterMarkerSegments.stream().map(ParameterMarkerSegment::getParameterIndex).collect(Collectors.toList()));
+            sql = convertSQLToJDBCStyle(parameterMarkerSegments, sql);
             sqlStatement = sqlParserEngine.parse(sql, true);
         }
         List<PostgreSQLColumnType> paddedColumnTypes = paddingColumnTypes(sqlStatement.getParameterCount(), packet.readParameterTypes());
         SQLStatementContext sqlStatementContext = sqlStatement instanceof DistSQLStatement ? new DistSQLStatementContext((DistSQLStatement) sqlStatement)
                 : SQLStatementContextFactory.newInstance(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(),
                         sqlStatement, connectionSession.getDefaultDatabaseName());
-        PostgreSQLServerPreparedStatement serverPreparedStatement = new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, paddedColumnTypes);
+        PostgreSQLServerPreparedStatement serverPreparedStatement = new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, paddedColumnTypes, actualParameterMarkerIndexes);
         connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(packet.getStatementId(), serverPreparedStatement);
         return Collections.singleton(PostgreSQLParseCompletePacket.getInstance());
     }
@@ -80,8 +86,7 @@ public final class PostgreSQLComParseExecutor implements CommandExecutor {
         return sqlParserRule.getSQLParserEngine(DatabaseTypeEngine.getTrunkDatabaseTypeName(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType()));
     }
     
-    private String convertSQLToJDBCStyle(final SQLStatement sqlStatement, final String sql) {
-        List<ParameterMarkerSegment> parameterMarkerSegments = new ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
+    private String convertSQLToJDBCStyle(final List<ParameterMarkerSegment> parameterMarkerSegments, final String sql) {
         parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
         StringBuilder result = new StringBuilder(sql);
         for (int i = parameterMarkerSegments.size() - 1; i >= 0; i--) {
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutorTest.java
index 0c0c0a819c9..1e746c51120 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutorTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutorTest.java
@@ -18,6 +18,7 @@
 package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.bind;
 
 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLBindCompletePacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLComBindPacket;
 import org.apache.shardingsphere.infra.binder.statement.UnknownSQLStatementContext;
@@ -42,8 +43,10 @@ import org.mockito.InjectMocks;
 import org.mockito.Mock;
 
 import java.sql.SQLException;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -63,7 +66,7 @@ class PostgreSQLComBindExecutorTest {
     @Mock
     private PostgreSQLComBindPacket bindPacket;
     
-    @Mock
+    @Mock(answer = Answers.CALLS_REAL_METHODS)
     private ConnectionSession connectionSession;
     
     @InjectMocks
@@ -94,4 +97,32 @@ class PostgreSQLComBindExecutorTest {
         assertThat(actual.iterator().next(), is(PostgreSQLBindCompletePacket.getInstance()));
         verify(portalContext).add(any(Portal.class));
     }
+    
+    @Test
+    void assertExecuteBindParameters() throws SQLException {
+        String databaseName = "postgres";
+        ShardingSphereDatabase database = mock(ShardingSphereDatabase.class);
+        when(database.getProtocolType()).thenReturn(new PostgreSQLDatabaseType());
+        when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(new ServerPreparedStatementRegistry());
+        ProxyDatabaseConnectionManager databaseConnectionManager = mock(ProxyDatabaseConnectionManager.class);
+        when(databaseConnectionManager.getConnectionSession()).thenReturn(connectionSession);
+        when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
+        when(connectionSession.getDefaultDatabaseName()).thenReturn(databaseName);
+        String statementId = "S_1";
+        List<Object> parameters = Arrays.asList(1, "updated_name"); // $1 = id, $2 = name
+        PostgreSQLServerPreparedStatement serverPreparedStatement = new PostgreSQLServerPreparedStatement("update test set name = ? where id = ?",
+                new UnknownSQLStatementContext(new PostgreSQLEmptyStatement()),
+                Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.VARCHAR),
+                Arrays.asList(1, 0)); // first '?' binds $2, second '?' binds $1
+        connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId, serverPreparedStatement);
+        when(bindPacket.getStatementId()).thenReturn(statementId);
+        when(bindPacket.getPortal()).thenReturn("C_1");
+        when(bindPacket.readParameters(anyList())).thenReturn(parameters);
+        when(bindPacket.readResultFormats()).thenReturn(Collections.emptyList());
+        ContextManager contextManager = mock(ContextManager.class, Answers.RETURNS_DEEP_STUBS);
+        when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        when(ProxyContext.getInstance().getDatabase(databaseName)).thenReturn(database);
+        executor.execute();
+        assertThat(connectionSession.getQueryContext().getParameters(), is(Arrays.asList(parameters.get(1), parameters.get(0))));
+    }
 }
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
index 9667e79c445..ada9e711e68 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
@@ -115,6 +115,24 @@ class PostgreSQLComParseExecutorTest {
         assertThat(actualPreparedStatement.getParameterTypes(), is(Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.UNSPECIFIED)));
     }
     
+    @Test
+    void assertExecuteWithNonOrderedParameterizedSQL() throws ReflectiveOperationException {
+        String rawSQL = "update t_test set name=$2 where id=$1";
+        String expectedSQL = "update t_test set name=? where id=?";
+        String statementId = "S_2";
+        when(parsePacket.getSQL()).thenReturn(rawSQL);
+        when(parsePacket.getStatementId()).thenReturn(statementId);
+        when(parsePacket.readParameterTypes()).thenReturn(Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.VARCHAR));
+        Plugins.getMemberAccessor().set(PostgreSQLComParseExecutor.class.getDeclaredField("connectionSession"), executor, connectionSession);
+        ContextManager contextManager = mockContextManager();
+        when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        executor.execute();
+        PostgreSQLServerPreparedStatement actualPreparedStatement = connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(statementId);
+        assertThat(actualPreparedStatement.getSql(), is(expectedSQL));
+        assertThat(actualPreparedStatement.getParameterTypes(), is(Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.VARCHAR)));
+        assertThat(actualPreparedStatement.getActualParameterMarkerIndexes(), is(Arrays.asList(1, 0))); // '?' for name binds $2, '?' for id binds $1
+    }
+    
     @Test
     void assertExecuteWithDistSQL() {
         String sql = "SHOW DIST VARIABLE WHERE NAME = sql_show";