You are viewing a plain-text version of this content. The link to its canonical version was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2022/07/28 02:05:15 UTC

[shardingsphere] branch master updated: Reuse SQLStatementContext in JDBCPortal (#19609)

This is an automated email from the ASF dual-hosted git repository.

panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 328510046db Reuse SQLStatementContext in JDBCPortal (#19609)
328510046db is described below

commit 328510046db5d84d7971fb94c3a01b3e3e4e2ad7
Author: 吴伟杰 <wu...@apache.org>
AuthorDate: Thu Jul 28 10:05:09 2022 +0800

    Reuse SQLStatementContext in JDBCPortal (#19609)
---
 .../command/query/extended/JDBCPortal.java         |  11 ++-
 .../command/query/extended/JDBCPortalTest.java     | 101 +++++++++++++++------
 2 files changed, 78 insertions(+), 34 deletions(-)

diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortal.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortal.java
index b5e8fb0310d..e70292db9bd 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortal.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortal.java
@@ -32,9 +32,8 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.ext
 import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLCommandCompletePacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLParameterStatusPacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
-import org.apache.shardingsphere.distsql.parser.statement.DistSQLStatement;
 import org.apache.shardingsphere.infra.binder.LogicSQL;
-import org.apache.shardingsphere.infra.binder.SQLStatementContextFactory;
+import org.apache.shardingsphere.infra.binder.aware.ParameterAware;
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import org.apache.shardingsphere.infra.database.type.DatabaseType;
 import org.apache.shardingsphere.infra.database.type.DatabaseTypeFactory;
@@ -87,14 +86,16 @@ public final class JDBCPortal implements Portal<Void> {
         this.sqlStatement = preparedStatement.getSqlStatement();
         this.resultFormats = resultFormats;
         this.backendConnection = backendConnection;
-        if (sqlStatement instanceof EmptyStatement || sqlStatement instanceof DistSQLStatement) {
+        if (!preparedStatement.getSqlStatementContext().isPresent()) {
             textProtocolBackendHandler = TextProtocolBackendHandlerFactory.newInstance(DatabaseTypeFactory.getInstance("PostgreSQL"),
                     preparedStatement.getSql(), sqlStatement, backendConnection.getConnectionSession());
             return;
         }
         String databaseName = backendConnection.getConnectionSession().getDefaultDatabaseName();
-        SQLStatementContext<?> sqlStatementContext = SQLStatementContextFactory.newInstance(
-                ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabases(), parameters, sqlStatement, databaseName);
+        SQLStatementContext<?> sqlStatementContext = preparedStatement.getSqlStatementContext().get();
+        if (sqlStatementContext instanceof ParameterAware) {
+            ((ParameterAware) sqlStatementContext).setUpParameters(parameters);
+        }
         DatabaseType databaseType = getDatabaseType(databaseName);
         LogicSQL logicSQL = new LogicSQL(sqlStatementContext, preparedStatement.getSql(), parameters);
         textProtocolBackendHandler = TextProtocolBackendHandlerFactory.newInstance(databaseType, logicSQL, backendConnection.getConnectionSession(), true);
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortalTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortalTest.java
index 1efffa97a77..64e155f79d6 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortalTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortalTest.java
@@ -17,7 +17,6 @@
 
 package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended;
 
-import lombok.SneakyThrows;
 import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLValueFormat;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLDataRowPacket;
@@ -26,7 +25,14 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.Pos
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLRowDescriptionPacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.execute.PostgreSQLPortalSuspendedPacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLCommandCompletePacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLParameterStatusPacket;
+import org.apache.shardingsphere.infra.binder.LogicSQL;
+import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
+import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
+import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
 import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
+import org.apache.shardingsphere.infra.database.type.dialect.PostgreSQLDatabaseType;
+import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import org.apache.shardingsphere.mode.manager.ContextManager;
 import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.JDBCBackendConnection;
 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
@@ -37,18 +43,24 @@ import org.apache.shardingsphere.proxy.backend.response.header.query.QueryRespon
 import org.apache.shardingsphere.proxy.backend.response.header.update.UpdateResponseHeader;
 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
 import org.apache.shardingsphere.proxy.backend.text.TextProtocolBackendHandler;
+import org.apache.shardingsphere.proxy.backend.text.TextProtocolBackendHandlerFactory;
 import org.apache.shardingsphere.proxy.frontend.postgresql.ProxyContextRestorer;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dal.VariableAssignSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dal.VariableSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.EmptyStatement;
-import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
-import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dal.PostgreSQLSetStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLInsertStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLSelectStatement;
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Answers;
 import org.mockito.Mock;
+import org.mockito.MockedStatic;
 import org.mockito.junit.MockitoJUnitRunner;
 
-import java.lang.reflect.Field;
 import java.sql.SQLException;
 import java.sql.Types;
 import java.util.ArrayList;
@@ -60,7 +72,13 @@ import java.util.List;
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.is;
 import static org.junit.Assert.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -79,32 +97,38 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
     @Mock
     private JDBCBackendConnection backendConnection;
     
-    private JDBCPortal portal;
+    private MockedStatic<TextProtocolBackendHandlerFactory> mockedStatic;
     
     @Before
     public void setup() throws SQLException {
         ProxyContext.init(mockContextManager);
+        when(mockContextManager.getMetaDataContexts().getMetaData().containsDatabase("db")).thenReturn(true);
         when(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getProps().getValue(ConfigurationPropertyKey.SQL_SHOW)).thenReturn(false);
+        when(connectionSession.getDefaultDatabaseName()).thenReturn("db");
+        ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
+        when(database.getResource().getDatabaseType()).thenReturn(new PostgreSQLDatabaseType());
+        when(ProxyContext.getInstance().getDatabase("db")).thenReturn(database);
         when(backendConnection.getConnectionSession()).thenReturn(connectionSession);
-        prepareJDBCPortal();
+        mockedStatic = mockStatic(TextProtocolBackendHandlerFactory.class);
+        mockedStatic.when(() -> TextProtocolBackendHandlerFactory.newInstance(any(PostgreSQLDatabaseType.class), anyString(), any(SQLStatement.class), eq(connectionSession)))
+                .thenReturn(textProtocolBackendHandler);
+        mockedStatic.when(() -> TextProtocolBackendHandlerFactory.newInstance(any(PostgreSQLDatabaseType.class), any(LogicSQL.class), eq(connectionSession), anyBoolean()))
+                .thenReturn(textProtocolBackendHandler);
     }
     
-    private void prepareJDBCPortal() throws SQLException {
-        PostgreSQLPreparedStatement preparedStatement = mock(PostgreSQLPreparedStatement.class);
-        when(preparedStatement.getSql()).thenReturn("");
-        when(preparedStatement.getSqlStatement()).thenReturn(new EmptyStatement());
-        List<PostgreSQLValueFormat> resultFormats = new ArrayList<>(Arrays.asList(PostgreSQLValueFormat.TEXT, PostgreSQLValueFormat.BINARY));
-        portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), resultFormats, backendConnection);
-        setField(portal, "textProtocolBackendHandler", textProtocolBackendHandler);
+    @After
+    public void tearDown() {
+        mockedStatic.close();
     }
     
     @Test
-    public void assertGetName() {
+    public void assertGetName() throws SQLException {
+        JDBCPortal portal = new JDBCPortal("", new PostgreSQLPreparedStatement("", null, null, Collections.emptyList()), Collections.emptyList(), Collections.emptyList(), backendConnection);
         assertThat(portal.getName(), is(""));
     }
     
     @Test
-    public void assertExecuteSelectStatementWithDatabaseCommunicationEngineAndReturnAllRows() throws SQLException {
+    public void assertExecuteSelectStatementAndReturnAllRows() throws SQLException {
         QueryResponseHeader responseHeader = mock(QueryResponseHeader.class);
         QueryHeader queryHeader = new QueryHeader("schema", "table", "columnLabel", "columnName", Types.INTEGER, "columnTypeName", 0, 0, false, false, false, false);
         when(responseHeader.getQueryHeaders()).thenReturn(Collections.singletonList(queryHeader));
@@ -112,9 +136,11 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
         when(textProtocolBackendHandler.next()).thenReturn(true, true, false);
         when(textProtocolBackendHandler.getRowData()).thenReturn(new QueryResponseRow(Collections.singletonList(new QueryResponseCell(Types.INTEGER, 0))),
                 new QueryResponseRow(Collections.singletonList(new QueryResponseCell(Types.INTEGER, 1))));
+        PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new PostgreSQLSelectStatement(), mock(SelectStatementContext.class), Collections.emptyList());
+        List<PostgreSQLValueFormat> resultFormats = new ArrayList<>(Arrays.asList(PostgreSQLValueFormat.TEXT, PostgreSQLValueFormat.BINARY));
+        JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), resultFormats, backendConnection);
         portal.bind();
         assertThat(portal.describe(), instanceOf(PostgreSQLRowDescriptionPacket.class));
-        setField(portal, "sqlStatement", mock(SelectStatement.class));
         List<PostgreSQLPacket> actualPackets = portal.execute(0);
         assertThat(actualPackets.size(), is(3));
         Iterator<PostgreSQLPacket> actualPacketsIterator = actualPackets.iterator();
@@ -124,7 +150,7 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
     }
     
     @Test
-    public void assertExecuteSelectStatementWithDatabaseCommunicationEngineAndPortalSuspended() throws SQLException {
+    public void assertExecuteSelectStatementAndPortalSuspended() throws SQLException {
         QueryResponseHeader responseHeader = mock(QueryResponseHeader.class);
         QueryHeader queryHeader = new QueryHeader("schema", "table", "columnLabel", "columnName", Types.INTEGER, "columnTypeName", 0, 0, false, false, false, false);
         when(responseHeader.getQueryHeaders()).thenReturn(Collections.singletonList(queryHeader));
@@ -133,10 +159,11 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
         when(textProtocolBackendHandler.getRowData()).thenReturn(
                 new QueryResponseRow(Collections.singletonList(new QueryResponseCell(Types.INTEGER, 0))),
                 new QueryResponseRow(Collections.singletonList(new QueryResponseCell(Types.INTEGER, 1))));
-        setField(portal, "resultFormats", Collections.singletonList(PostgreSQLValueFormat.BINARY));
+        PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new PostgreSQLSelectStatement(), mock(SelectStatementContext.class), Collections.emptyList());
+        List<PostgreSQLValueFormat> resultFormats = new ArrayList<>(Arrays.asList(PostgreSQLValueFormat.TEXT, PostgreSQLValueFormat.BINARY));
+        JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), resultFormats, backendConnection);
         portal.bind();
         assertThat(portal.describe(), instanceOf(PostgreSQLRowDescriptionPacket.class));
-        setField(portal, "sqlStatement", mock(SelectStatement.class));
         List<PostgreSQLPacket> actualPackets = portal.execute(2);
         assertThat(actualPackets.size(), is(3));
         Iterator<PostgreSQLPacket> actualPacketsIterator = actualPackets.iterator();
@@ -146,26 +173,47 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
     }
     
     @Test
-    public void assertExecuteUpdateWithDatabaseCommunicationEngine() throws SQLException {
+    public void assertExecuteUpdate() throws SQLException {
         when(textProtocolBackendHandler.execute()).thenReturn(mock(UpdateResponseHeader.class));
         when(textProtocolBackendHandler.next()).thenReturn(false);
+        PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new PostgreSQLInsertStatement(), mock(InsertStatementContext.class), Collections.emptyList());
+        JDBCPortal portal = new JDBCPortal("insert into t values (1)", preparedStatement, Collections.emptyList(), Collections.emptyList(), backendConnection);
         portal.bind();
         assertThat(portal.describe(), is(PostgreSQLNoDataPacket.getInstance()));
-        setField(portal, "sqlStatement", mock(InsertStatement.class));
         List<PostgreSQLPacket> actualPackets = portal.execute(0);
         assertThat(actualPackets.iterator().next(), instanceOf(PostgreSQLCommandCompletePacket.class));
     }
     
     @Test
-    public void assertExecuteEmptyStatementWithDatabaseCommunicationEngine() throws SQLException {
+    public void assertExecuteEmptyStatement() throws SQLException {
         when(textProtocolBackendHandler.execute()).thenReturn(mock(UpdateResponseHeader.class));
         when(textProtocolBackendHandler.next()).thenReturn(false);
+        PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new EmptyStatement(), null, Collections.emptyList());
+        JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), Collections.emptyList(), backendConnection);
         portal.bind();
         assertThat(portal.describe(), is(PostgreSQLNoDataPacket.getInstance()));
         List<PostgreSQLPacket> actualPackets = portal.execute(0);
         assertThat(actualPackets.iterator().next(), instanceOf(PostgreSQLEmptyQueryResponsePacket.class));
     }
     
+    @Test
+    public void assertExecuteSetStatement() throws SQLException {
+        when(textProtocolBackendHandler.execute()).thenReturn(mock(UpdateResponseHeader.class));
+        when(textProtocolBackendHandler.next()).thenReturn(false);
+        String sql = "set client_encoding = utf8";
+        PostgreSQLSetStatement setStatement = new PostgreSQLSetStatement();
+        VariableAssignSegment variableAssignSegment = new VariableAssignSegment();
+        variableAssignSegment.setVariable(new VariableSegment());
+        setStatement.getVariableAssigns().add(variableAssignSegment);
+        PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement(sql, setStatement, new CommonSQLStatementContext<>(setStatement), Collections.emptyList());
+        JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), Collections.emptyList(), backendConnection);
+        portal.bind();
+        List<PostgreSQLPacket> actualPackets = portal.execute(0);
+        assertThat(actualPackets.size(), is(2));
+        assertThat(actualPackets.get(0), instanceOf(PostgreSQLCommandCompletePacket.class));
+        assertThat(actualPackets.get(1), instanceOf(PostgreSQLParameterStatusPacket.class));
+    }
+    
     @Test(expected = IllegalStateException.class)
     public void assertDescribeBeforeBind() throws SQLException {
         PostgreSQLPreparedStatement preparedStatement = mock(PostgreSQLPreparedStatement.class);
@@ -176,15 +224,10 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
     
     @Test
     public void assertClose() throws SQLException {
+        PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new EmptyStatement(), null, Collections.emptyList());
+        JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), Collections.emptyList(), backendConnection);
         portal.close();
         verify(backendConnection).unmarkResourceInUse(textProtocolBackendHandler);
         verify(textProtocolBackendHandler).close();
     }
-    
-    @SneakyThrows
-    private void setField(final JDBCPortal portal, final String fieldName, final Object value) {
-        Field field = JDBCPortal.class.getDeclaredField(fieldName);
-        field.setAccessible(true);
-        field.set(portal, value);
-    }
 }