You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2022/07/28 02:05:15 UTC
[shardingsphere] branch master updated: Reuse SQLStatementContext in JDBCPortal (#19609)
This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 328510046db Reuse SQLStatementContext in JDBCPortal (#19609)
328510046db is described below
commit 328510046db5d84d7971fb94c3a01b3e3e4e2ad7
Author: 吴伟杰 <wu...@apache.org>
AuthorDate: Thu Jul 28 10:05:09 2022 +0800
Reuse SQLStatementContext in JDBCPortal (#19609)
---
.../command/query/extended/JDBCPortal.java | 11 ++-
.../command/query/extended/JDBCPortalTest.java | 101 +++++++++++++++------
2 files changed, 78 insertions(+), 34 deletions(-)
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortal.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortal.java
index b5e8fb0310d..e70292db9bd 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortal.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortal.java
@@ -32,9 +32,8 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.ext
import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLCommandCompletePacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLParameterStatusPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
-import org.apache.shardingsphere.distsql.parser.statement.DistSQLStatement;
import org.apache.shardingsphere.infra.binder.LogicSQL;
-import org.apache.shardingsphere.infra.binder.SQLStatementContextFactory;
+import org.apache.shardingsphere.infra.binder.aware.ParameterAware;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
import org.apache.shardingsphere.infra.database.type.DatabaseTypeFactory;
@@ -87,14 +86,16 @@ public final class JDBCPortal implements Portal<Void> {
this.sqlStatement = preparedStatement.getSqlStatement();
this.resultFormats = resultFormats;
this.backendConnection = backendConnection;
- if (sqlStatement instanceof EmptyStatement || sqlStatement instanceof DistSQLStatement) {
+ if (!preparedStatement.getSqlStatementContext().isPresent()) {
textProtocolBackendHandler = TextProtocolBackendHandlerFactory.newInstance(DatabaseTypeFactory.getInstance("PostgreSQL"),
preparedStatement.getSql(), sqlStatement, backendConnection.getConnectionSession());
return;
}
String databaseName = backendConnection.getConnectionSession().getDefaultDatabaseName();
- SQLStatementContext<?> sqlStatementContext = SQLStatementContextFactory.newInstance(
- ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabases(), parameters, sqlStatement, databaseName);
+ SQLStatementContext<?> sqlStatementContext = preparedStatement.getSqlStatementContext().get();
+ if (sqlStatementContext instanceof ParameterAware) {
+ ((ParameterAware) sqlStatementContext).setUpParameters(parameters);
+ }
DatabaseType databaseType = getDatabaseType(databaseName);
LogicSQL logicSQL = new LogicSQL(sqlStatementContext, preparedStatement.getSql(), parameters);
textProtocolBackendHandler = TextProtocolBackendHandlerFactory.newInstance(databaseType, logicSQL, backendConnection.getConnectionSession(), true);
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortalTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortalTest.java
index 1efffa97a77..64e155f79d6 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortalTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/JDBCPortalTest.java
@@ -17,7 +17,6 @@
package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended;
-import lombok.SneakyThrows;
import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLValueFormat;
import org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLDataRowPacket;
@@ -26,7 +25,14 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.Pos
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLRowDescriptionPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.execute.PostgreSQLPortalSuspendedPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLCommandCompletePacket;
+import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLParameterStatusPacket;
+import org.apache.shardingsphere.infra.binder.LogicSQL;
+import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
+import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
+import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
+import org.apache.shardingsphere.infra.database.type.dialect.PostgreSQLDatabaseType;
+import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.JDBCBackendConnection;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
@@ -37,18 +43,24 @@ import org.apache.shardingsphere.proxy.backend.response.header.query.QueryRespon
import org.apache.shardingsphere.proxy.backend.response.header.update.UpdateResponseHeader;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.backend.text.TextProtocolBackendHandler;
+import org.apache.shardingsphere.proxy.backend.text.TextProtocolBackendHandlerFactory;
import org.apache.shardingsphere.proxy.frontend.postgresql.ProxyContextRestorer;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dal.VariableAssignSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dal.VariableSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.EmptyStatement;
-import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
-import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dal.PostgreSQLSetStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLInsertStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLSelectStatement;
+import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Answers;
import org.mockito.Mock;
+import org.mockito.MockedStatic;
import org.mockito.junit.MockitoJUnitRunner;
-import java.lang.reflect.Field;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
@@ -60,7 +72,13 @@ import java.util.List;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -79,32 +97,38 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
@Mock
private JDBCBackendConnection backendConnection;
- private JDBCPortal portal;
+ private MockedStatic<TextProtocolBackendHandlerFactory> mockedStatic;
@Before
public void setup() throws SQLException {
ProxyContext.init(mockContextManager);
+ when(mockContextManager.getMetaDataContexts().getMetaData().containsDatabase("db")).thenReturn(true);
when(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getProps().getValue(ConfigurationPropertyKey.SQL_SHOW)).thenReturn(false);
+ when(connectionSession.getDefaultDatabaseName()).thenReturn("db");
+ ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
+ when(database.getResource().getDatabaseType()).thenReturn(new PostgreSQLDatabaseType());
+ when(ProxyContext.getInstance().getDatabase("db")).thenReturn(database);
when(backendConnection.getConnectionSession()).thenReturn(connectionSession);
- prepareJDBCPortal();
+ mockedStatic = mockStatic(TextProtocolBackendHandlerFactory.class);
+ mockedStatic.when(() -> TextProtocolBackendHandlerFactory.newInstance(any(PostgreSQLDatabaseType.class), anyString(), any(SQLStatement.class), eq(connectionSession)))
+ .thenReturn(textProtocolBackendHandler);
+ mockedStatic.when(() -> TextProtocolBackendHandlerFactory.newInstance(any(PostgreSQLDatabaseType.class), any(LogicSQL.class), eq(connectionSession), anyBoolean()))
+ .thenReturn(textProtocolBackendHandler);
}
- private void prepareJDBCPortal() throws SQLException {
- PostgreSQLPreparedStatement preparedStatement = mock(PostgreSQLPreparedStatement.class);
- when(preparedStatement.getSql()).thenReturn("");
- when(preparedStatement.getSqlStatement()).thenReturn(new EmptyStatement());
- List<PostgreSQLValueFormat> resultFormats = new ArrayList<>(Arrays.asList(PostgreSQLValueFormat.TEXT, PostgreSQLValueFormat.BINARY));
- portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), resultFormats, backendConnection);
- setField(portal, "textProtocolBackendHandler", textProtocolBackendHandler);
+ @After
+ public void tearDown() {
+ mockedStatic.close();
}
@Test
- public void assertGetName() {
+ public void assertGetName() throws SQLException {
+ JDBCPortal portal = new JDBCPortal("", new PostgreSQLPreparedStatement("", null, null, Collections.emptyList()), Collections.emptyList(), Collections.emptyList(), backendConnection);
assertThat(portal.getName(), is(""));
}
@Test
- public void assertExecuteSelectStatementWithDatabaseCommunicationEngineAndReturnAllRows() throws SQLException {
+ public void assertExecuteSelectStatementAndReturnAllRows() throws SQLException {
QueryResponseHeader responseHeader = mock(QueryResponseHeader.class);
QueryHeader queryHeader = new QueryHeader("schema", "table", "columnLabel", "columnName", Types.INTEGER, "columnTypeName", 0, 0, false, false, false, false);
when(responseHeader.getQueryHeaders()).thenReturn(Collections.singletonList(queryHeader));
@@ -112,9 +136,11 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
when(textProtocolBackendHandler.next()).thenReturn(true, true, false);
when(textProtocolBackendHandler.getRowData()).thenReturn(new QueryResponseRow(Collections.singletonList(new QueryResponseCell(Types.INTEGER, 0))),
new QueryResponseRow(Collections.singletonList(new QueryResponseCell(Types.INTEGER, 1))));
+ PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new PostgreSQLSelectStatement(), mock(SelectStatementContext.class), Collections.emptyList());
+ List<PostgreSQLValueFormat> resultFormats = new ArrayList<>(Arrays.asList(PostgreSQLValueFormat.TEXT, PostgreSQLValueFormat.BINARY));
+ JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), resultFormats, backendConnection);
portal.bind();
assertThat(portal.describe(), instanceOf(PostgreSQLRowDescriptionPacket.class));
- setField(portal, "sqlStatement", mock(SelectStatement.class));
List<PostgreSQLPacket> actualPackets = portal.execute(0);
assertThat(actualPackets.size(), is(3));
Iterator<PostgreSQLPacket> actualPacketsIterator = actualPackets.iterator();
@@ -124,7 +150,7 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
}
@Test
- public void assertExecuteSelectStatementWithDatabaseCommunicationEngineAndPortalSuspended() throws SQLException {
+ public void assertExecuteSelectStatementAndPortalSuspended() throws SQLException {
QueryResponseHeader responseHeader = mock(QueryResponseHeader.class);
QueryHeader queryHeader = new QueryHeader("schema", "table", "columnLabel", "columnName", Types.INTEGER, "columnTypeName", 0, 0, false, false, false, false);
when(responseHeader.getQueryHeaders()).thenReturn(Collections.singletonList(queryHeader));
@@ -133,10 +159,11 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
when(textProtocolBackendHandler.getRowData()).thenReturn(
new QueryResponseRow(Collections.singletonList(new QueryResponseCell(Types.INTEGER, 0))),
new QueryResponseRow(Collections.singletonList(new QueryResponseCell(Types.INTEGER, 1))));
- setField(portal, "resultFormats", Collections.singletonList(PostgreSQLValueFormat.BINARY));
+ PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new PostgreSQLSelectStatement(), mock(SelectStatementContext.class), Collections.emptyList());
+ List<PostgreSQLValueFormat> resultFormats = new ArrayList<>(Arrays.asList(PostgreSQLValueFormat.TEXT, PostgreSQLValueFormat.BINARY));
+ JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), resultFormats, backendConnection);
portal.bind();
assertThat(portal.describe(), instanceOf(PostgreSQLRowDescriptionPacket.class));
- setField(portal, "sqlStatement", mock(SelectStatement.class));
List<PostgreSQLPacket> actualPackets = portal.execute(2);
assertThat(actualPackets.size(), is(3));
Iterator<PostgreSQLPacket> actualPacketsIterator = actualPackets.iterator();
@@ -146,26 +173,47 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
}
@Test
- public void assertExecuteUpdateWithDatabaseCommunicationEngine() throws SQLException {
+ public void assertExecuteUpdate() throws SQLException {
when(textProtocolBackendHandler.execute()).thenReturn(mock(UpdateResponseHeader.class));
when(textProtocolBackendHandler.next()).thenReturn(false);
+ PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new PostgreSQLInsertStatement(), mock(InsertStatementContext.class), Collections.emptyList());
+ JDBCPortal portal = new JDBCPortal("insert into t values (1)", preparedStatement, Collections.emptyList(), Collections.emptyList(), backendConnection);
portal.bind();
assertThat(portal.describe(), is(PostgreSQLNoDataPacket.getInstance()));
- setField(portal, "sqlStatement", mock(InsertStatement.class));
List<PostgreSQLPacket> actualPackets = portal.execute(0);
assertThat(actualPackets.iterator().next(), instanceOf(PostgreSQLCommandCompletePacket.class));
}
@Test
- public void assertExecuteEmptyStatementWithDatabaseCommunicationEngine() throws SQLException {
+ public void assertExecuteEmptyStatement() throws SQLException {
when(textProtocolBackendHandler.execute()).thenReturn(mock(UpdateResponseHeader.class));
when(textProtocolBackendHandler.next()).thenReturn(false);
+ PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new EmptyStatement(), null, Collections.emptyList());
+ JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), Collections.emptyList(), backendConnection);
portal.bind();
assertThat(portal.describe(), is(PostgreSQLNoDataPacket.getInstance()));
List<PostgreSQLPacket> actualPackets = portal.execute(0);
assertThat(actualPackets.iterator().next(), instanceOf(PostgreSQLEmptyQueryResponsePacket.class));
}
+ @Test
+ public void assertExecuteSetStatement() throws SQLException {
+ when(textProtocolBackendHandler.execute()).thenReturn(mock(UpdateResponseHeader.class));
+ when(textProtocolBackendHandler.next()).thenReturn(false);
+ String sql = "set client_encoding = utf8";
+ PostgreSQLSetStatement setStatement = new PostgreSQLSetStatement();
+ VariableAssignSegment variableAssignSegment = new VariableAssignSegment();
+ variableAssignSegment.setVariable(new VariableSegment());
+ setStatement.getVariableAssigns().add(variableAssignSegment);
+ PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement(sql, setStatement, new CommonSQLStatementContext<>(setStatement), Collections.emptyList());
+ JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), Collections.emptyList(), backendConnection);
+ portal.bind();
+ List<PostgreSQLPacket> actualPackets = portal.execute(0);
+ assertThat(actualPackets.size(), is(2));
+ assertThat(actualPackets.get(0), instanceOf(PostgreSQLCommandCompletePacket.class));
+ assertThat(actualPackets.get(1), instanceOf(PostgreSQLParameterStatusPacket.class));
+ }
+
@Test(expected = IllegalStateException.class)
public void assertDescribeBeforeBind() throws SQLException {
PostgreSQLPreparedStatement preparedStatement = mock(PostgreSQLPreparedStatement.class);
@@ -176,15 +224,10 @@ public final class JDBCPortalTest extends ProxyContextRestorer {
@Test
public void assertClose() throws SQLException {
+ PostgreSQLPreparedStatement preparedStatement = new PostgreSQLPreparedStatement("", new EmptyStatement(), null, Collections.emptyList());
+ JDBCPortal portal = new JDBCPortal("", preparedStatement, Collections.emptyList(), Collections.emptyList(), backendConnection);
portal.close();
verify(backendConnection).unmarkResourceInUse(textProtocolBackendHandler);
verify(textProtocolBackendHandler).close();
}
-
- @SneakyThrows
- private void setField(final JDBCPortal portal, final String fieldName, final Object value) {
- Field field = JDBCPortal.class.getDeclaredField(fieldName);
- field.setAccessible(true);
- field.set(portal, value);
- }
}