You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2022/06/14 09:06:26 UTC
[shardingsphere] branch master updated: Move unchanged SQLStatement into MySQLPreparedStatement (#18357)
This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 7ef928755d9 Move unchanged SQLStatement into MySQLPreparedStatement (#18357)
7ef928755d9 is described below
commit 7ef928755d92b365fb83492f9e0eb14320b118cb
Author: 吴伟杰 <wu...@apache.org>
AuthorDate: Tue Jun 14 17:06:20 2022 +0800
Move unchanged SQLStatement into MySQLPreparedStatement (#18357)
---
.../query/binary/MySQLPreparedStatement.java | 3 ++-
.../binary/MySQLPreparedStatementRegistry.java | 7 ++++---
.../binary/execute/MySQLComStmtExecutePacket.java | 5 +++--
.../command/MySQLMySQLCommandPacketFactoryTest.java | 3 ++-
.../binary/MySQLPreparedStatementRegistryTest.java | 21 ++++++++++++++-------
.../execute/MySQLComStmtExecutePacketTest.java | 5 ++++-
.../binary/execute/MySQLComStmtExecuteExecutor.java | 8 +-------
.../binary/prepare/MySQLComStmtPrepareExecutor.java | 5 ++---
.../command/MySQLCommandExecutorFactoryTest.java | 7 ++++++-
.../execute/MySQLComStmtExecuteExecutorTest.java | 19 +++++++++++++++----
10 files changed, 53 insertions(+), 30 deletions(-)
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatement.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatement.java
index 1a03fd3ec80..edf282e8572 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatement.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatement.java
@@ -20,6 +20,7 @@ package org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import java.util.List;
@@ -33,7 +34,7 @@ public final class MySQLPreparedStatement {
private final String sql;
- private final int parameterCount;
+ private final SQLStatement sqlStatement;
private List<MySQLPreparedStatementParameterType> parameterTypes;
}
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatementRegistry.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatementRegistry.java
index e2eb11db478..273e56dfccf 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatementRegistry.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatementRegistry.java
@@ -19,6 +19,7 @@ package org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@@ -82,12 +83,12 @@ public final class MySQLPreparedStatementRegistry {
* Prepare statement.
*
* @param sql SQL
- * @param parameterCount parameter count
+ * @param sqlStatement sql statement of prepared statement
* @return statement ID
*/
- public int prepareStatement(final String sql, final int parameterCount) {
+ public int prepareStatement(final String sql, final SQLStatement sqlStatement) {
int result = sequence.incrementAndGet();
- preparedStatements.put(result, new MySQLPreparedStatement(sql, parameterCount));
+ preparedStatements.put(result, new MySQLPreparedStatement(sql, sqlStatement));
return result;
}
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacket.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacket.java
index 4e0b399192a..0de41c2e8e1 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacket.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacket.java
@@ -50,6 +50,7 @@ public final class MySQLComStmtExecutePacket extends MySQLCommandPacket {
private final int statementId;
+ @Getter
private final MySQLPreparedStatement preparedStatement;
private final int flags;
@@ -70,7 +71,7 @@ public final class MySQLComStmtExecutePacket extends MySQLCommandPacket {
preparedStatement = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(connectionId).get(statementId);
flags = payload.readInt1();
Preconditions.checkArgument(ITERATION_COUNT == payload.readInt4());
- int parameterCount = preparedStatement.getParameterCount();
+ int parameterCount = preparedStatement.getSqlStatement().getParameterCount();
sql = preparedStatement.getSql();
if (parameterCount > 0) {
nullBitmap = new MySQLNullBitmap(parameterCount, NULL_BITMAP_OFFSET);
@@ -113,7 +114,7 @@ public final class MySQLComStmtExecutePacket extends MySQLCommandPacket {
payload.writeInt4(statementId);
payload.writeInt1(flags);
payload.writeInt4(ITERATION_COUNT);
- if (preparedStatement.getParameterCount() > 0) {
+ if (preparedStatement.getSqlStatement().getParameterCount() > 0) {
for (int each : nullBitmap.getNullBitmap()) {
payload.writeInt1(each);
}
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLMySQLCommandPacketFactoryTest.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLMySQLCommandPacketFactoryTest.java
index 6051a3f4bcc..0fdb096d564 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLMySQLCommandPacketFactoryTest.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLMySQLCommandPacketFactoryTest.java
@@ -31,6 +31,7 @@ import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.r
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.text.fieldlist.MySQLComFieldListPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.text.query.MySQLComQueryPacket;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@@ -81,7 +82,7 @@ public final class MySQLMySQLCommandPacketFactoryTest {
when(payload.readInt1()).thenReturn(MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST.getValue());
when(payload.readInt4()).thenReturn(1);
MySQLPreparedStatementRegistry.getInstance().registerConnection(CONNECTION_ID);
- MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement("SELECT * FROM t_order", 1);
+ MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement("SELECT * FROM t_order", new MySQLSelectStatement());
assertThat(MySQLCommandPacketFactory.newInstance(MySQLCommandPacketType.COM_STMT_EXECUTE, payload, CONNECTION_ID), instanceOf(MySQLComStmtExecutePacket.class));
MySQLPreparedStatementRegistry.getInstance().unregisterConnection(CONNECTION_ID);
}
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatementRegistryTest.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatementRegistryTest.java
index 156461b4caf..3808fa3a86b 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatementRegistryTest.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatementRegistryTest.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -38,32 +39,38 @@ public final class MySQLPreparedStatementRegistryTest {
@Test
public void assertRegisterIfAbsent() {
- assertThat(MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement(SQL, 1), is(1));
+ assertThat(MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement(SQL, prepareSQLStatement()), is(1));
MySQLPreparedStatement actual = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).get(1);
assertThat(actual.getSql(), is(SQL));
- assertThat(actual.getParameterCount(), is(1));
+ assertThat(actual.getSqlStatement().getParameterCount(), is(1));
}
@Test
public void assertPrepareSameSQL() {
- assertThat(MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement(SQL, 1), is(1));
- assertThat(MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement(SQL, 1), is(2));
+ assertThat(MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement(SQL, prepareSQLStatement()), is(1));
+ assertThat(MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement(SQL, prepareSQLStatement()), is(2));
MySQLPreparedStatement actual = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).get(1);
assertThat(actual.getSql(), is(SQL));
- assertThat(actual.getParameterCount(), is(1));
+ assertThat(actual.getSqlStatement().getParameterCount(), is(1));
actual = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).get(1);
assertThat(actual.getSql(), is(SQL));
- assertThat(actual.getParameterCount(), is(1));
+ assertThat(actual.getSqlStatement().getParameterCount(), is(1));
}
@Test
public void assertCloseStatement() {
- MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement(SQL, 1);
+ MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement(SQL, prepareSQLStatement());
MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).closeStatement(1);
MySQLPreparedStatement actual = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).get(1);
assertNull(actual);
}
+ private MySQLSelectStatement prepareSQLStatement() {
+ MySQLSelectStatement result = new MySQLSelectStatement();
+ result.setParameterCount(1);
+ return result;
+ }
+
@After
public void tearDown() {
MySQLPreparedStatementRegistry.getInstance().unregisterConnection(CONNECTION_ID);
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacketTest.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacketTest.java
index 9cba1bb3eda..74b7fdd85e2 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacketTest.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacketTest.java
@@ -19,6 +19,7 @@ package org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementRegistry;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -45,7 +46,9 @@ public final class MySQLComStmtExecutePacketTest {
@Before
public void setup() {
MySQLPreparedStatementRegistry.getInstance().registerConnection(CONNECTION_ID);
- MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement("SELECT id FROM tbl WHERE id=?", 1);
+ MySQLSelectStatement sqlStatement = new MySQLSelectStatement();
+ sqlStatement.setParameterCount(1);
+ MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement("SELECT id FROM tbl WHERE id=?", sqlStatement);
}
@Test
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutor.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutor.java
index f3923a42414..e5756a69b52 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutor.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutor.java
@@ -17,7 +17,6 @@
package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.execute;
-import com.google.common.base.Preconditions;
import lombok.Getter;
import org.apache.shardingsphere.db.protocol.binary.BinaryCell;
import org.apache.shardingsphere.db.protocol.binary.BinaryRow;
@@ -30,12 +29,10 @@ import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
import org.apache.shardingsphere.infra.binder.SQLStatementContextFactory;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.type.TableAvailable;
-import org.apache.shardingsphere.infra.database.type.DatabaseTypeEngine;
import org.apache.shardingsphere.infra.database.type.DatabaseTypeFactory;
import org.apache.shardingsphere.infra.executor.check.SQLCheckEngine;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
-import org.apache.shardingsphere.parser.rule.SQLParserRule;
import org.apache.shardingsphere.proxy.backend.communication.DatabaseCommunicationEngineFactory;
import org.apache.shardingsphere.proxy.backend.communication.SQLStatementDatabaseHolder;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.JDBCDatabaseCommunicationEngine;
@@ -83,10 +80,7 @@ public final class MySQLComStmtExecuteExecutor implements QueryCommandExecutor {
public MySQLComStmtExecuteExecutor(final MySQLComStmtExecutePacket packet, final ConnectionSession connectionSession) throws SQLException {
String databaseName = connectionSession.getDatabaseName();
MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
- Optional<SQLParserRule> sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().findSingleRule(SQLParserRule.class);
- Preconditions.checkState(sqlParserRule.isPresent());
- SQLStatement sqlStatement = sqlParserRule.get().getSQLParserEngine(
- DatabaseTypeEngine.getTrunkDatabaseTypeName(metaDataContexts.getMetaData().getDatabases().get(databaseName).getProtocolType())).parse(packet.getSql(), true);
+ SQLStatement sqlStatement = packet.getPreparedStatement().getSqlStatement();
if (AutoCommitUtils.needOpenTransaction(sqlStatement)) {
connectionSession.getBackendConnection().handleAutoCommit();
}
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareExecutor.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareExecutor.java
index 3cadd04f639..f6607625549 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareExecutor.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareExecutor.java
@@ -77,10 +77,9 @@ public final class MySQLComStmtPrepareExecutor implements CommandExecutor {
if (!MySQLComStmtPrepareChecker.isStatementAllowed(sqlStatement)) {
throw new UnsupportedPreparedStatementException();
}
- int parameterCount = sqlStatement.getParameterCount();
int projectionCount = getProjectionCount(sqlStatement);
- int statementId = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(connectionSession.getConnectionId()).prepareStatement(packet.getSql(), parameterCount);
- return createPackets(statementId, projectionCount, parameterCount);
+ int statementId = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(connectionSession.getConnectionId()).prepareStatement(packet.getSql(), sqlStatement);
+ return createPackets(statementId, projectionCount, sqlStatement.getParameterCount());
}
private void failedIfContainsMultiStatements() {
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecutorFactoryTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecutorFactoryTest.java
index 63f4e79be4b..d3435b835c0 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecutorFactoryTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecutorFactoryTest.java
@@ -55,6 +55,8 @@ import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.prepa
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.reset.MySQLComStmtResetExecutor;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.text.fieldlist.MySQLComFieldListPacketExecutor;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.text.query.MySQLComQueryPacketExecutor;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -146,8 +148,11 @@ public final class MySQLCommandExecutorFactoryTest extends ProxyContextRestorer
@Test
public void assertNewInstanceWithComStmtExecute() throws SQLException {
- MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class);
+ MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class, RETURNS_DEEP_STUBS);
when(packet.getSql()).thenReturn("SELECT 1");
+ MySQLSelectStatement sqlStatement = new MySQLSelectStatement();
+ sqlStatement.setProjections(new ProjectionsSegment(0, 1));
+ when(packet.getPreparedStatement().getSqlStatement()).thenReturn(sqlStatement);
assertThat(MySQLCommandExecutorFactory.newInstance(MySQLCommandPacketType.COM_STMT_EXECUTE, packet, connectionSession), instanceOf(MySQLComStmtExecuteExecutor.class));
}
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutorTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutorTest.java
index c85c18043ee..b6f75acaed8 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutorTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutorTest.java
@@ -41,8 +41,11 @@ import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.backend.text.TextProtocolBackendHandler;
import org.apache.shardingsphere.proxy.frontend.command.executor.ResponseType;
import org.apache.shardingsphere.proxy.frontend.mysql.ProxyContextRestorer;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.tcl.CommitStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.tcl.MySQLCommitStatement;
import org.apache.shardingsphere.transaction.rule.TransactionRule;
import org.junit.Before;
import org.junit.Test;
@@ -92,7 +95,6 @@ public final class MySQLComStmtExecuteExecutorTest extends ProxyContextRestorer
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getBackendConnection()).thenReturn(backendConnection);
when(backendConnection.getConnectionSession()).thenReturn(connectionSession);
- when(contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().findSingleRule(SQLParserRule.class)).thenReturn(Optional.of(sqlParserRule));
}
private ShardingSphereDatabase mockDatabase() {
@@ -107,8 +109,9 @@ public final class MySQLComStmtExecuteExecutorTest extends ProxyContextRestorer
public void assertIsQueryResponse() throws NoSuchFieldException, SQLException, IllegalAccessException {
when(connectionSession.getDatabaseName()).thenReturn("logic_db");
when(connectionSession.getDefaultDatabaseName()).thenReturn("logic_db");
- MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class);
+ MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class, RETURNS_DEEP_STUBS);
when(packet.getSql()).thenReturn("SELECT 1");
+ when(packet.getPreparedStatement().getSqlStatement()).thenReturn(prepareSQLStatement());
MySQLComStmtExecuteExecutor mysqlComStmtExecuteExecutor = new MySQLComStmtExecuteExecutor(packet, connectionSession);
MemberAccessor accessor = Plugins.getMemberAccessor();
accessor.set(MySQLComStmtExecuteExecutor.class.getDeclaredField("databaseCommunicationEngine"), mysqlComStmtExecuteExecutor, databaseCommunicationEngine);
@@ -121,8 +124,9 @@ public final class MySQLComStmtExecuteExecutorTest extends ProxyContextRestorer
public void assertIsUpdateResponse() throws NoSuchFieldException, SQLException, IllegalAccessException {
when(connectionSession.getDatabaseName()).thenReturn("logic_db");
when(connectionSession.getDefaultDatabaseName()).thenReturn("logic_db");
- MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class);
+ MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class, RETURNS_DEEP_STUBS);
when(packet.getSql()).thenReturn("SELECT 1");
+ when(packet.getPreparedStatement().getSqlStatement()).thenReturn(prepareSQLStatement());
MySQLComStmtExecuteExecutor mysqlComStmtExecuteExecutor = new MySQLComStmtExecuteExecutor(packet, connectionSession);
MemberAccessor accessor = Plugins.getMemberAccessor();
accessor.set(MySQLComStmtExecuteExecutor.class.getDeclaredField("databaseCommunicationEngine"), mysqlComStmtExecuteExecutor, databaseCommunicationEngine);
@@ -131,12 +135,19 @@ public final class MySQLComStmtExecuteExecutorTest extends ProxyContextRestorer
assertThat(mysqlComStmtExecuteExecutor.getResponseType(), is(ResponseType.UPDATE));
}
+ private MySQLSelectStatement prepareSQLStatement() {
+ MySQLSelectStatement sqlStatement = new MySQLSelectStatement();
+ sqlStatement.setProjections(new ProjectionsSegment(0, 0));
+ return sqlStatement;
+ }
+
@Test
public void assertExecutePreparedCommit() throws SQLException, NoSuchFieldException, IllegalAccessException {
when(connectionSession.getDatabaseName()).thenReturn("logic_db");
when(connectionSession.getDefaultDatabaseName()).thenReturn("logic_db");
- MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class);
+ MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class, RETURNS_DEEP_STUBS);
when(packet.getSql()).thenReturn("commit");
+ when(packet.getPreparedStatement().getSqlStatement()).thenReturn(new MySQLCommitStatement());
MySQLComStmtExecuteExecutor mysqlComStmtExecuteExecutor = new MySQLComStmtExecuteExecutor(packet, connectionSession);
TextProtocolBackendHandler textProtocolBackendHandler = mock(TextProtocolBackendHandler.class);
MemberAccessor accessor = Plugins.getMemberAccessor();