You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2022/06/16 05:39:24 UTC
[shardingsphere] branch master updated: Move logic of COM_STMT from packet to executor (#18384)
This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 687ab04cef0 Move logic of COM_STMT from packet to executor (#18384)
687ab04cef0 is described below
commit 687ab04cef0969399155e1dff009cfdede74c915
Author: 吴伟杰 <wu...@apache.org>
AuthorDate: Thu Jun 16 13:39:18 2022 +0800
Move logic of COM_STMT from packet to executor (#18384)
* Move logic of COM_STMT from packet to executor
* Update MySQLComStmtExecutePacketTest
* Complete MySQLCommandPacketFactoryTest
* Complete MySQLCommandExecutorFactoryTest
* Complete MySQLComStmtExecuteExecutorTest
* Fix checkstyle in MySQLCommandExecutorFactoryTest
* Make parameterTypes in MySQLPreparedStatement non-null by default (initialized to an empty list)
---
.../query/binary/MySQLPreparedStatement.java | 3 +-
.../binary/execute/MySQLComStmtExecutePacket.java | 65 ++++-------
.../execute/MySQLComStmtExecutePacketTest.java | 76 ++++++------
.../mysql/command/MySQLCommandExecuteEngine.java | 1 -
.../mysql}/command/MySQLCommandPacketFactory.java | 10 +-
.../execute/MySQLComStmtExecuteExecutor.java | 54 +++++----
.../command/MySQLCommandExecutorFactoryTest.java | 10 +-
.../command/MySQLCommandPacketFactoryTest.java | 10 +-
.../execute/MySQLComStmtExecuteExecutorTest.java | 130 +++++++++++++--------
.../ReactiveMySQLComStmtExecuteExecutor.java | 74 ++++++------
10 files changed, 236 insertions(+), 197 deletions(-)
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatement.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatement.java
index edf282e8572..1b356b54413 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatement.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/MySQLPreparedStatement.java
@@ -22,6 +22,7 @@ import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
+import java.util.Collections;
import java.util.List;
/**
@@ -36,5 +37,5 @@ public final class MySQLPreparedStatement {
private final SQLStatement sqlStatement;
- private List<MySQLPreparedStatementParameterType> parameterTypes;
+ private List<MySQLPreparedStatementParameterType> parameterTypes = Collections.emptyList();
}
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacket.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacket.java
index 0de41c2e8e1..14c2bb672f4 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacket.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacket.java
@@ -24,9 +24,7 @@ import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLBinaryColumnTyp
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLNewParametersBoundFlag;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacketType;
-import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatement;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementParameterType;
-import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementRegistry;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute.protocol.MySQLBinaryProtocolValue;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute.protocol.MySQLBinaryProtocolValueFactory;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
@@ -41,56 +39,49 @@ import java.util.List;
*
* @see <a href="https://dev.mysql.com/doc/internals/en/com-stmt-execute.html">COM_STMT_EXECUTE</a>
*/
-@ToString(of = {"sql", "parameters"})
+@ToString(of = {"statementId"})
public final class MySQLComStmtExecutePacket extends MySQLCommandPacket {
private static final int ITERATION_COUNT = 1;
private static final int NULL_BITMAP_OFFSET = 0;
- private final int statementId;
+ private final MySQLPacketPayload payload;
@Getter
- private final MySQLPreparedStatement preparedStatement;
+ private final int statementId;
private final int flags;
private final MySQLNullBitmap nullBitmap;
- private final MySQLNewParametersBoundFlag newParametersBoundFlag;
-
@Getter
- private final String sql;
+ private final MySQLNewParametersBoundFlag newParametersBoundFlag;
@Getter
- private final List<Object> parameters;
+ private final List<MySQLPreparedStatementParameterType> newParameterTypes;
- public MySQLComStmtExecutePacket(final MySQLPacketPayload payload, final int connectionId) throws SQLException {
+ public MySQLComStmtExecutePacket(final MySQLPacketPayload payload, final int parameterCount) throws SQLException {
super(MySQLCommandPacketType.COM_STMT_EXECUTE);
+ this.payload = payload;
statementId = payload.readInt4();
- preparedStatement = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(connectionId).get(statementId);
flags = payload.readInt1();
Preconditions.checkArgument(ITERATION_COUNT == payload.readInt4());
- int parameterCount = preparedStatement.getSqlStatement().getParameterCount();
- sql = preparedStatement.getSql();
if (parameterCount > 0) {
nullBitmap = new MySQLNullBitmap(parameterCount, NULL_BITMAP_OFFSET);
for (int i = 0; i < nullBitmap.getNullBitmap().length; i++) {
nullBitmap.getNullBitmap()[i] = payload.readInt1();
}
newParametersBoundFlag = MySQLNewParametersBoundFlag.valueOf(payload.readInt1());
- if (MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST == newParametersBoundFlag) {
- preparedStatement.setParameterTypes(getParameterTypes(payload, parameterCount));
- }
- parameters = getParameters(payload, parameterCount);
+ newParameterTypes = MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST == newParametersBoundFlag ? getNewParameterTypes(parameterCount) : Collections.emptyList();
} else {
nullBitmap = null;
newParametersBoundFlag = null;
- parameters = Collections.emptyList();
+ newParameterTypes = Collections.emptyList();
}
}
- private List<MySQLPreparedStatementParameterType> getParameterTypes(final MySQLPacketPayload payload, final int parameterCount) {
+ private List<MySQLPreparedStatementParameterType> getNewParameterTypes(final int parameterCount) {
List<MySQLPreparedStatementParameterType> result = new ArrayList<>(parameterCount);
for (int parameterIndex = 0; parameterIndex < parameterCount; parameterIndex++) {
MySQLBinaryColumnType columnType = MySQLBinaryColumnType.valueOf(payload.readInt1());
@@ -100,33 +91,19 @@ public final class MySQLComStmtExecutePacket extends MySQLCommandPacket {
return result;
}
- private List<Object> getParameters(final MySQLPacketPayload payload, final int parameterCount) throws SQLException {
- List<Object> result = new ArrayList<>(parameterCount);
- for (int parameterIndex = 0; parameterIndex < parameterCount; parameterIndex++) {
- MySQLBinaryProtocolValue binaryProtocolValue = MySQLBinaryProtocolValueFactory.getBinaryProtocolValue(preparedStatement.getParameterTypes().get(parameterIndex).getColumnType());
+ /**
+ * Read parameter values from packet.
+ *
+ * @param parameterTypes parameter type of values
+ * @return parameter values
+ * @throws SQLException SQL exception
+ */
+ public List<Object> readParameters(final List<MySQLPreparedStatementParameterType> parameterTypes) throws SQLException {
+ List<Object> result = new ArrayList<>(parameterTypes.size());
+ for (int parameterIndex = 0; parameterIndex < parameterTypes.size(); parameterIndex++) {
+ MySQLBinaryProtocolValue binaryProtocolValue = MySQLBinaryProtocolValueFactory.getBinaryProtocolValue(parameterTypes.get(parameterIndex).getColumnType());
result.add(nullBitmap.isNullParameter(parameterIndex) ? null : binaryProtocolValue.read(payload));
}
return result;
}
-
- @Override
- public void doWrite(final MySQLPacketPayload payload) {
- payload.writeInt4(statementId);
- payload.writeInt1(flags);
- payload.writeInt4(ITERATION_COUNT);
- if (preparedStatement.getSqlStatement().getParameterCount() > 0) {
- for (int each : nullBitmap.getNullBitmap()) {
- payload.writeInt1(each);
- }
- payload.writeInt1(newParametersBoundFlag.getValue());
- int count = 0;
- for (Object each : parameters) {
- MySQLPreparedStatementParameterType parameterType = preparedStatement.getParameterTypes().get(count);
- payload.writeInt1(parameterType.getColumnType().getValue());
- payload.writeInt1(parameterType.getUnsignedFlag());
- payload.writeStringLenenc(null == each ? "" : each.toString());
- count++;
- }
- }
- }
}
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacketTest.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacketTest.java
index 74b7fdd85e2..2a57d47c9ae 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacketTest.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/binary/execute/MySQLComStmtExecutePacketTest.java
@@ -17,69 +17,71 @@
package org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute;
+import io.netty.buffer.Unpooled;
+import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLBinaryColumnType;
+import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLNewParametersBoundFlag;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementParameterType;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementRegistry;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Before;
import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.junit.MockitoJUnitRunner;
+import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.Collections;
+import java.util.List;
import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.junit.Assert.assertTrue;
-@RunWith(MockitoJUnitRunner.class)
public final class MySQLComStmtExecutePacketTest {
- private static final int CONNECTION_ID = 1;
-
- @Mock
- private MySQLPacketPayload payload;
-
@Before
public void setup() {
- MySQLPreparedStatementRegistry.getInstance().registerConnection(CONNECTION_ID);
+ MySQLPreparedStatementRegistry.getInstance().registerConnection(1);
MySQLSelectStatement sqlStatement = new MySQLSelectStatement();
sqlStatement.setParameterCount(1);
- MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement("SELECT id FROM tbl WHERE id=?", sqlStatement);
+ MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(1).prepareStatement("SELECT id FROM tbl WHERE id=?", sqlStatement);
}
@Test
- public void assertNewWithNotNullParameters() throws SQLException {
- when(payload.readInt4()).thenReturn(1);
- when(payload.readInt1()).thenReturn(0, 0, 1);
- MySQLComStmtExecutePacket actual = new MySQLComStmtExecutePacket(payload, CONNECTION_ID);
- assertThat(actual.getSequenceId(), is(0));
- assertThat(actual.getSql(), is("SELECT id FROM tbl WHERE id=?"));
- assertThat(actual.getParameters(), is(Collections.<Object>singletonList(1)));
+ public void assertNewWithoutParameter() throws SQLException {
+ byte[] data = {0x01, 0x00, 0x00, 0x00, 0x09, 0x01, 0x00, 0x00, 0x00};
+ MySQLPacketPayload payload = new MySQLPacketPayload(Unpooled.wrappedBuffer(data), StandardCharsets.UTF_8);
+ MySQLComStmtExecutePacket actual = new MySQLComStmtExecutePacket(payload, 0);
+ assertThat(actual.getStatementId(), is(1));
+ assertNull(actual.getNewParametersBoundFlag());
+ assertTrue(actual.getNewParameterTypes().isEmpty());
}
@Test
- public void assertNewWithNullParameters() throws SQLException {
- when(payload.readInt4()).thenReturn(1);
- when(payload.readInt1()).thenReturn(0, 1);
- MySQLComStmtExecutePacket actual = new MySQLComStmtExecutePacket(payload, CONNECTION_ID);
- assertThat(actual.getSequenceId(), is(0));
- assertThat(actual.getSql(), is("SELECT id FROM tbl WHERE id=?"));
- assertThat(actual.getParameters(), is(Collections.singletonList(null)));
+ public void assertNewParameterBoundWithNotNullParameters() throws SQLException {
+ byte[] data = {0x01, 0x00, 0x00, 0x00, 0x09, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x00, 0x01, 0x00, 0x00, 0x00};
+ MySQLPacketPayload payload = new MySQLPacketPayload(Unpooled.wrappedBuffer(data), StandardCharsets.UTF_8);
+ MySQLComStmtExecutePacket actual = new MySQLComStmtExecutePacket(payload, 1);
+ assertThat(actual.getStatementId(), is(1));
+ assertThat(actual.getNewParametersBoundFlag(), is(MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST));
+ List<MySQLPreparedStatementParameterType> parameterTypes = actual.getNewParameterTypes();
+ assertThat(parameterTypes.size(), is(1));
+ assertThat(parameterTypes.get(0).getColumnType(), is(MySQLBinaryColumnType.MYSQL_TYPE_LONG));
+ assertThat(parameterTypes.get(0).getUnsignedFlag(), is(0));
+ assertThat(actual.readParameters(parameterTypes), is(Collections.<Object>singletonList(1)));
}
@Test
- public void assertWrite() throws SQLException {
- when(payload.readInt4()).thenReturn(1);
- when(payload.readInt1()).thenReturn(0, 1);
- MySQLComStmtExecutePacket actual = new MySQLComStmtExecutePacket(payload, CONNECTION_ID);
- actual.write(payload);
- verify(payload, times(2)).writeInt4(1);
- verify(payload, times(4)).writeInt1(1);
- verify(payload).writeInt1(0);
- verify(payload).writeStringLenenc("");
+ public void assertNewWithNullParameters() throws SQLException {
+ byte[] data = {0x01, 0x00, 0x00, 0x00, 0x09, 0x01, 0x00, 0x00, 0x00, 0x01, 0x01, 0x03, 0x00};
+ MySQLPacketPayload payload = new MySQLPacketPayload(Unpooled.wrappedBuffer(data), StandardCharsets.UTF_8);
+ MySQLComStmtExecutePacket actual = new MySQLComStmtExecutePacket(payload, 1);
+ assertThat(actual.getStatementId(), is(1));
+ assertThat(actual.getNewParametersBoundFlag(), is(MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST));
+ List<MySQLPreparedStatementParameterType> parameterTypes = actual.getNewParameterTypes();
+ assertThat(parameterTypes.size(), is(1));
+ assertThat(parameterTypes.get(0).getColumnType(), is(MySQLBinaryColumnType.MYSQL_TYPE_LONG));
+ assertThat(parameterTypes.get(0).getUnsignedFlag(), is(0));
+ assertThat(actual.readParameters(parameterTypes), is(Collections.singletonList(null)));
}
}
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecuteEngine.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecuteEngine.java
index aed9df443eb..ee3ceb7cf5c 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecuteEngine.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecuteEngine.java
@@ -19,7 +19,6 @@ package org.apache.shardingsphere.proxy.frontend.mysql.command;
import io.netty.channel.ChannelHandlerContext;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacket;
-import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacketFactory;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacketType;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacketTypeLoader;
import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLEofPacket;
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLCommandPacketFactory.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandPacketFactory.java
similarity index 81%
rename from shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLCommandPacketFactory.java
rename to shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandPacketFactory.java
index 986248c9a27..eec4205d8f4 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLCommandPacketFactory.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandPacketFactory.java
@@ -15,15 +15,19 @@
* limitations under the License.
*/
-package org.apache.shardingsphere.db.protocol.mysql.packet.command;
+package org.apache.shardingsphere.proxy.frontend.mysql.command;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacket;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacketType;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.MySQLComSetOptionPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.MySQLUnsupportedCommandPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.initdb.MySQLComInitDbPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.ping.MySQLComPingPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.quit.MySQLComQuitPacket;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatement;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementRegistry;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.close.MySQLComStmtClosePacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute.MySQLComStmtExecutePacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPreparePacket;
@@ -62,7 +66,9 @@ public final class MySQLCommandPacketFactory {
case COM_STMT_PREPARE:
return new MySQLComStmtPreparePacket(payload);
case COM_STMT_EXECUTE:
- return new MySQLComStmtExecutePacket(payload, connectionId);
+ MySQLPreparedStatement preparedStatement = MySQLPreparedStatementRegistry.getInstance()
+ .getConnectionPreparedStatements(connectionId).get(payload.getByteBuf().getIntLE(payload.getByteBuf().readerIndex()));
+ return new MySQLComStmtExecutePacket(payload, preparedStatement.getSqlStatement().getParameterCount());
case COM_STMT_RESET:
return new MySQLComStmtResetPacket(payload);
case COM_STMT_CLOSE:
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutor.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutor.java
index e5756a69b52..1646b743c95 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutor.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutor.java
@@ -18,11 +18,15 @@
package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.execute;
import lombok.Getter;
+import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.db.protocol.binary.BinaryCell;
import org.apache.shardingsphere.db.protocol.binary.BinaryRow;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLBinaryColumnType;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
+import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLNewParametersBoundFlag;
import org.apache.shardingsphere.db.protocol.mysql.packet.MySQLPacket;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatement;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementRegistry;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute.MySQLBinaryResultSetRowPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute.MySQLComStmtExecutePacket;
import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
@@ -64,44 +68,58 @@ import java.util.Optional;
/**
* COM_STMT_EXECUTE command executor for MySQL.
*/
+@RequiredArgsConstructor
public final class MySQLComStmtExecuteExecutor implements QueryCommandExecutor {
- private final JDBCDatabaseCommunicationEngine databaseCommunicationEngine;
+ private final MySQLComStmtExecutePacket packet;
- private final TextProtocolBackendHandler textProtocolBackendHandler;
+ private final ConnectionSession connectionSession;
- private final int characterSet;
+ private JDBCDatabaseCommunicationEngine databaseCommunicationEngine;
+
+ private TextProtocolBackendHandler textProtocolBackendHandler;
@Getter
- private volatile ResponseType responseType;
+ private ResponseType responseType;
private int currentSequenceId;
- public MySQLComStmtExecuteExecutor(final MySQLComStmtExecutePacket packet, final ConnectionSession connectionSession) throws SQLException {
+ @Override
+ public Collection<DatabasePacket<?>> execute() throws SQLException {
+ MySQLPreparedStatement preparedStatement = updateAndGetPreparedStatement();
String databaseName = connectionSession.getDatabaseName();
MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
- SQLStatement sqlStatement = packet.getPreparedStatement().getSqlStatement();
+ SQLStatement sqlStatement = preparedStatement.getSqlStatement();
if (AutoCommitUtils.needOpenTransaction(sqlStatement)) {
connectionSession.getBackendConnection().handleAutoCommit();
}
- SQLStatementContext<?> sqlStatementContext = SQLStatementContextFactory.newInstance(metaDataContexts.getMetaData().getDatabases(), packet.getParameters(),
+ List<Object> parameters = packet.readParameters(preparedStatement.getParameterTypes());
+ SQLStatementContext<?> sqlStatementContext = SQLStatementContextFactory.newInstance(metaDataContexts.getMetaData().getDatabases(), parameters,
sqlStatement, connectionSession.getDefaultDatabaseName());
// TODO optimize SQLStatementDatabaseHolder
if (sqlStatementContext instanceof TableAvailable) {
((TableAvailable) sqlStatementContext).getTablesContext().getDatabaseName().ifPresent(SQLStatementDatabaseHolder::set);
}
SQLCheckEngine.check(sqlStatement, Collections.emptyList(), getRules(databaseName), databaseName, metaDataContexts.getMetaData().getDatabases(), connectionSession.getGrantee());
- characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
// TODO Refactor the following branch
if (sqlStatement instanceof TCLStatement) {
- databaseCommunicationEngine = null;
textProtocolBackendHandler =
- TextProtocolBackendHandlerFactory.newInstance(DatabaseTypeFactory.getInstance("MySQL"), packet.getSql(), () -> Optional.of(sqlStatement), connectionSession);
- return;
+ TextProtocolBackendHandlerFactory.newInstance(DatabaseTypeFactory.getInstance("MySQL"), preparedStatement.getSql(), () -> Optional.of(sqlStatement), connectionSession);
+ } else {
+ databaseCommunicationEngine = DatabaseCommunicationEngineFactory.getInstance().newBinaryProtocolInstance(sqlStatementContext, preparedStatement.getSql(), parameters,
+ connectionSession.getBackendConnection());
}
- textProtocolBackendHandler = null;
- databaseCommunicationEngine = DatabaseCommunicationEngineFactory.getInstance().newBinaryProtocolInstance(sqlStatementContext, packet.getSql(), packet.getParameters(),
- connectionSession.getBackendConnection());
+ ResponseHeader responseHeader = null != databaseCommunicationEngine ? databaseCommunicationEngine.execute() : textProtocolBackendHandler.execute();
+ int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
+ return responseHeader instanceof QueryResponseHeader ? processQuery((QueryResponseHeader) responseHeader, characterSet) : processUpdate((UpdateResponseHeader) responseHeader);
+ }
+
+ private MySQLPreparedStatement updateAndGetPreparedStatement() {
+ MySQLPreparedStatement result = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(connectionSession.getConnectionId()).get(packet.getStatementId());
+ if (MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST == packet.getNewParametersBoundFlag()) {
+ result.setParameterTypes(packet.getNewParameterTypes());
+ }
+ return result;
}
private static Collection<ShardingSphereRule> getRules(final String databaseName) {
@@ -111,13 +129,7 @@ public final class MySQLComStmtExecuteExecutor implements QueryCommandExecutor {
return result;
}
- @Override
- public Collection<DatabasePacket<?>> execute() throws SQLException {
- ResponseHeader responseHeader = null != databaseCommunicationEngine ? databaseCommunicationEngine.execute() : textProtocolBackendHandler.execute();
- return responseHeader instanceof QueryResponseHeader ? processQuery((QueryResponseHeader) responseHeader) : processUpdate((UpdateResponseHeader) responseHeader);
- }
-
- private Collection<DatabasePacket<?>> processQuery(final QueryResponseHeader queryResponseHeader) {
+ private Collection<DatabasePacket<?>> processQuery(final QueryResponseHeader queryResponseHeader, final int characterSet) {
responseType = ResponseType.QUERY;
Collection<DatabasePacket<?>> result = ResponsePacketBuilder.buildQueryResponsePackets(queryResponseHeader, characterSet);
currentSequenceId = result.size();
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecutorFactoryTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecutorFactoryTest.java
index d3435b835c0..759995587eb 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecutorFactoryTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandExecutorFactoryTest.java
@@ -55,8 +55,6 @@ import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.prepa
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.reset.MySQLComStmtResetExecutor;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.text.fieldlist.MySQLComFieldListPacketExecutor;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.text.query.MySQLComQueryPacketExecutor;
-import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
-import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -148,12 +146,8 @@ public final class MySQLCommandExecutorFactoryTest extends ProxyContextRestorer
@Test
public void assertNewInstanceWithComStmtExecute() throws SQLException {
- MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class, RETURNS_DEEP_STUBS);
- when(packet.getSql()).thenReturn("SELECT 1");
- MySQLSelectStatement sqlStatement = new MySQLSelectStatement();
- sqlStatement.setProjections(new ProjectionsSegment(0, 1));
- when(packet.getPreparedStatement().getSqlStatement()).thenReturn(sqlStatement);
- assertThat(MySQLCommandExecutorFactory.newInstance(MySQLCommandPacketType.COM_STMT_EXECUTE, packet, connectionSession), instanceOf(MySQLComStmtExecuteExecutor.class));
+ assertThat(MySQLCommandExecutorFactory.newInstance(MySQLCommandPacketType.COM_STMT_EXECUTE, mock(MySQLComStmtExecutePacket.class), connectionSession),
+ instanceOf(MySQLComStmtExecuteExecutor.class));
}
@Test
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLMySQLCommandPacketFactoryTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandPacketFactoryTest.java
similarity index 96%
rename from shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLMySQLCommandPacketFactoryTest.java
rename to shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandPacketFactoryTest.java
index 0fdb096d564..99ef170e35f 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-mysql/src/test/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/MySQLMySQLCommandPacketFactoryTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/MySQLCommandPacketFactoryTest.java
@@ -15,9 +15,10 @@
* limitations under the License.
*/
-package org.apache.shardingsphere.db.protocol.mysql.packet.command;
+package org.apache.shardingsphere.proxy.frontend.mysql.command;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLNewParametersBoundFlag;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.MySQLCommandPacketType;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.MySQLComSetOptionPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.MySQLUnsupportedCommandPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.initdb.MySQLComInitDbPacket;
@@ -34,6 +35,7 @@ import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
@@ -41,14 +43,15 @@ import java.sql.SQLException;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertThat;
+import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
-public final class MySQLMySQLCommandPacketFactoryTest {
+public final class MySQLCommandPacketFactoryTest {
private static final int CONNECTION_ID = 1;
- @Mock
+ @Mock(answer = Answers.RETURNS_DEEP_STUBS)
private MySQLPacketPayload payload;
@Test
@@ -81,6 +84,7 @@ public final class MySQLMySQLCommandPacketFactoryTest {
public void assertNewInstanceWithComStmtExecutePacket() throws SQLException {
when(payload.readInt1()).thenReturn(MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST.getValue());
when(payload.readInt4()).thenReturn(1);
+ when(payload.getByteBuf().getIntLE(anyInt())).thenReturn(1);
MySQLPreparedStatementRegistry.getInstance().registerConnection(CONNECTION_ID);
MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(CONNECTION_ID).prepareStatement("SELECT * FROM t_order", new MySQLSelectStatement());
assertThat(MySQLCommandPacketFactory.newInstance(MySQLCommandPacketType.COM_STMT_EXECUTE, payload, CONNECTION_ID), instanceOf(MySQLComStmtExecutePacket.class));
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutorTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutorTest.java
index b6f75acaed8..eb875fe6104 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutorTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/execute/MySQLComStmtExecuteExecutorTest.java
@@ -19,7 +19,14 @@ package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.exec
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLCharacterSet;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinition41Packet;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLFieldCountPacket;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementRegistry;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute.MySQLComStmtExecutePacket;
+import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLEofPacket;
+import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLOKPacket;
+import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.database.type.dialect.MySQLDatabaseType;
import org.apache.shardingsphere.infra.federation.optimizer.context.OptimizerContext;
@@ -29,8 +36,7 @@ import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRule
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.mode.metadata.persist.MetaDataPersistService;
-import org.apache.shardingsphere.parser.rule.SQLParserRule;
-import org.apache.shardingsphere.parser.rule.builder.DefaultSQLParserRuleConfigurationBuilder;
+import org.apache.shardingsphere.proxy.backend.communication.DatabaseCommunicationEngineFactory;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.JDBCDatabaseCommunicationEngine;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.JDBCBackendConnection;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
@@ -39,39 +45,48 @@ import org.apache.shardingsphere.proxy.backend.response.header.query.QueryRespon
import org.apache.shardingsphere.proxy.backend.response.header.update.UpdateResponseHeader;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.backend.text.TextProtocolBackendHandler;
+import org.apache.shardingsphere.proxy.backend.text.TextProtocolBackendHandlerFactory;
import org.apache.shardingsphere.proxy.frontend.command.executor.ResponseType;
import org.apache.shardingsphere.proxy.frontend.mysql.ProxyContextRestorer;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
-import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
-import org.apache.shardingsphere.sql.parser.sql.common.statement.tcl.CommitStatement;
+import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLUpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.tcl.MySQLCommitStatement;
-import org.apache.shardingsphere.transaction.rule.TransactionRule;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Answers;
import org.mockito.Mock;
-import org.mockito.internal.configuration.plugins.Plugins;
+import org.mockito.MockedStatic;
import org.mockito.junit.MockitoJUnitRunner;
-import org.mockito.plugins.MemberAccessor;
import java.sql.SQLException;
import java.util.Collections;
-import java.util.Optional;
+import java.util.Iterator;
import java.util.Properties;
+import java.util.function.Supplier;
+import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyList;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public final class MySQLComStmtExecuteExecutorTest extends ProxyContextRestorer {
- private final SQLParserRule sqlParserRule = new SQLParserRule(new DefaultSQLParserRuleConfigurationBuilder().build());
-
@Mock
private JDBCDatabaseCommunicationEngine databaseCommunicationEngine;
@@ -85,16 +100,21 @@ public final class MySQLComStmtExecuteExecutorTest extends ProxyContextRestorer
public void setUp() {
ShardingSphereDatabase database = mockDatabase();
ShardingSphereRuleMetaData metaData = mock(ShardingSphereRuleMetaData.class);
- when(metaData.findSingleRule(TransactionRule.class)).thenReturn(Optional.of(mock(TransactionRule.class)));
MetaDataContexts metaDataContexts = new MetaDataContexts(mock(MetaDataPersistService.class),
new ShardingSphereMetaData(Collections.singletonMap("logic_db", database), metaData, new ConfigurationProperties(new Properties())),
mock(OptimizerContext.class, RETURNS_DEEP_STUBS));
ContextManager contextManager = mock(ContextManager.class, RETURNS_DEEP_STUBS);
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
ProxyContext.init(contextManager);
+ // Connection id 1 must match the connection registered in MySQLPreparedStatementRegistry below.
+ when(connectionSession.getConnectionId()).thenReturn(1);
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getBackendConnection()).thenReturn(backendConnection);
- when(backendConnection.getConnectionSession()).thenReturn(connectionSession);
+ when(connectionSession.getDatabaseName()).thenReturn("logic_db");
+ when(connectionSession.getDefaultDatabaseName()).thenReturn("logic_db");
+ // Register the three statements exercised by the tests. They are presumably assigned statement ids 1 (SELECT), 2 (UPDATE)
+ // and 3 (COMMIT) in registration order, which is what the tests return from packet.getStatementId().
+ // NOTE(review): the registry is a singleton and nothing visible here unregisters connection 1; the ids above only hold if
+ // registerConnection(1) starts from a fresh statement list on every setUp run - confirm.
+ MySQLPreparedStatementRegistry.getInstance().registerConnection(1);
+ MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(1).prepareStatement("select * from tbl where id = ?", prepareSelectStatement());
+ MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(1).prepareStatement("update tbl set col=1 where id = ?", prepareUpdateStatement());
+ MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(1).prepareStatement("commit", new MySQLCommitStatement());
}
private ShardingSphereDatabase mockDatabase() {
@@ -105,55 +125,71 @@ public final class MySQLComStmtExecuteExecutorTest extends ProxyContextRestorer
return result;
}
+ private MySQLSelectStatement prepareSelectStatement() {
+ // Minimal SELECT statement (a projections segment with no projection items) backing "select * from tbl where id = ?" in setUp.
+ MySQLSelectStatement result = new MySQLSelectStatement();
+ result.setProjections(new ProjectionsSegment(0, 0));
+ return result;
+ }
+
+ private MySQLUpdateStatement prepareUpdateStatement() {
+ // UPDATE statement with a single assignment "col = <parameter marker>". Note: the SQL text registered in setUp is
+ // "update tbl set col=1 where id = ?"; the modelled statement differs, but both carry exactly one parameter marker.
+ ParameterMarkerExpressionSegment parameterMarker = new ParameterMarkerExpressionSegment(0, 0, 0);
+ ColumnSegment targetColumn = new ColumnSegment(0, 0, new IdentifierValue("col"));
+ ColumnAssignmentSegment assignment = new ColumnAssignmentSegment(0, 0, Collections.singletonList(targetColumn), parameterMarker);
+ MySQLUpdateStatement result = new MySQLUpdateStatement();
+ result.setSetAssignment(new SetAssignmentSegment(0, 0, Collections.singletonList(assignment)));
+ return result;
+ }
+
@Test
- public void assertIsQueryResponse() throws NoSuchFieldException, SQLException, IllegalAccessException {
- when(connectionSession.getDatabaseName()).thenReturn("logic_db");
- when(connectionSession.getDefaultDatabaseName()).thenReturn("logic_db");
- MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class, RETURNS_DEEP_STUBS);
- when(packet.getSql()).thenReturn("SELECT 1");
- when(packet.getPreparedStatement().getSqlStatement()).thenReturn(prepareSQLStatement());
+ public void assertIsQueryResponse() throws SQLException {
+ // Statement id 1 refers to the SELECT registered in setUp (assuming ids are assigned in registration order).
+ MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class);
+ when(packet.getStatementId()).thenReturn(1);
MySQLComStmtExecuteExecutor mysqlComStmtExecuteExecutor = new MySQLComStmtExecuteExecutor(packet, connectionSession);
- MemberAccessor accessor = Plugins.getMemberAccessor();
- accessor.set(MySQLComStmtExecuteExecutor.class.getDeclaredField("databaseCommunicationEngine"), mysqlComStmtExecuteExecutor, databaseCommunicationEngine);
when(databaseCommunicationEngine.execute()).thenReturn(new QueryResponseHeader(Collections.singletonList(mock(QueryHeader.class))));
- mysqlComStmtExecuteExecutor.execute();
+ Iterator<DatabasePacket<?>> actual;
+ // Make the static factory hand back the mocked engine; the static stub is only active until the try block closes it.
+ try (MockedStatic<DatabaseCommunicationEngineFactory> mockedStatic = mockStatic(DatabaseCommunicationEngineFactory.class, RETURNS_DEEP_STUBS)) {
+ mockedStatic.when(() -> DatabaseCommunicationEngineFactory.getInstance().newBinaryProtocolInstance(any(SQLStatementContext.class), anyString(), anyList(), eq(backendConnection)))
+ .thenReturn(databaseCommunicationEngine);
+ actual = mysqlComStmtExecuteExecutor.execute().iterator();
+ }
assertThat(mysqlComStmtExecuteExecutor.getResponseType(), is(ResponseType.QUERY));
+ // Expected packets for a one-column query response: field count, one column definition, then EOF.
+ assertThat(actual.next(), instanceOf(MySQLFieldCountPacket.class));
+ assertThat(actual.next(), instanceOf(MySQLColumnDefinition41Packet.class));
+ assertThat(actual.next(), instanceOf(MySQLEofPacket.class));
+ assertFalse(actual.hasNext());
}
@Test
- public void assertIsUpdateResponse() throws NoSuchFieldException, SQLException, IllegalAccessException {
- when(connectionSession.getDatabaseName()).thenReturn("logic_db");
- when(connectionSession.getDefaultDatabaseName()).thenReturn("logic_db");
- MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class, RETURNS_DEEP_STUBS);
- when(packet.getSql()).thenReturn("SELECT 1");
- when(packet.getPreparedStatement().getSqlStatement()).thenReturn(prepareSQLStatement());
+ public void assertIsUpdateResponse() throws SQLException {
+ // Statement id 2 refers to the UPDATE registered in setUp (assuming ids are assigned in registration order).
+ MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class);
+ when(packet.getStatementId()).thenReturn(2);
MySQLComStmtExecuteExecutor mysqlComStmtExecuteExecutor = new MySQLComStmtExecuteExecutor(packet, connectionSession);
- MemberAccessor accessor = Plugins.getMemberAccessor();
- accessor.set(MySQLComStmtExecuteExecutor.class.getDeclaredField("databaseCommunicationEngine"), mysqlComStmtExecuteExecutor, databaseCommunicationEngine);
- when(databaseCommunicationEngine.execute()).thenReturn(new UpdateResponseHeader(mock(SQLStatement.class)));
- mysqlComStmtExecuteExecutor.execute();
+ when(databaseCommunicationEngine.execute()).thenReturn(new UpdateResponseHeader(new MySQLUpdateStatement()));
+ Iterator<DatabasePacket<?>> actual;
+ // Make the static factory hand back the mocked engine; the static stub is only active until the try block closes it.
+ try (MockedStatic<DatabaseCommunicationEngineFactory> mockedStatic = mockStatic(DatabaseCommunicationEngineFactory.class, RETURNS_DEEP_STUBS)) {
+ mockedStatic.when(() -> DatabaseCommunicationEngineFactory.getInstance().newBinaryProtocolInstance(any(SQLStatementContext.class), anyString(), anyList(), eq(backendConnection)))
+ .thenReturn(databaseCommunicationEngine);
+ actual = mysqlComStmtExecuteExecutor.execute().iterator();
+ }
assertThat(mysqlComStmtExecuteExecutor.getResponseType(), is(ResponseType.UPDATE));
- }
-
- private MySQLSelectStatement prepareSQLStatement() {
- MySQLSelectStatement sqlStatement = new MySQLSelectStatement();
- sqlStatement.setProjections(new ProjectionsSegment(0, 0));
- return sqlStatement;
+ // An update response is expected to produce exactly one OK packet.
+ assertThat(actual.next(), instanceOf(MySQLOKPacket.class));
+ assertFalse(actual.hasNext());
}
@Test
- public void assertExecutePreparedCommit() throws SQLException, NoSuchFieldException, IllegalAccessException {
- when(connectionSession.getDatabaseName()).thenReturn("logic_db");
- when(connectionSession.getDefaultDatabaseName()).thenReturn("logic_db");
- MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class, RETURNS_DEEP_STUBS);
- when(packet.getSql()).thenReturn("commit");
- when(packet.getPreparedStatement().getSqlStatement()).thenReturn(new MySQLCommitStatement());
+ public void assertExecutePreparedCommit() throws SQLException {
+ // Statement id 3 refers to the COMMIT registered in setUp (assuming ids are assigned in registration order).
+ MySQLComStmtExecutePacket packet = mock(MySQLComStmtExecutePacket.class);
+ when(packet.getStatementId()).thenReturn(3);
MySQLComStmtExecuteExecutor mysqlComStmtExecuteExecutor = new MySQLComStmtExecuteExecutor(packet, connectionSession);
TextProtocolBackendHandler textProtocolBackendHandler = mock(TextProtocolBackendHandler.class);
- MemberAccessor accessor = Plugins.getMemberAccessor();
- accessor.set(MySQLComStmtExecuteExecutor.class.getDeclaredField("textProtocolBackendHandler"), mysqlComStmtExecuteExecutor, textProtocolBackendHandler);
- when(textProtocolBackendHandler.execute()).thenReturn(new UpdateResponseHeader(mock(CommitStatement.class)));
- mysqlComStmtExecuteExecutor.execute();
+ when(textProtocolBackendHandler.execute()).thenReturn(new UpdateResponseHeader(new MySQLCommitStatement()));
+ Iterator<DatabasePacket<?>> actual;
+ // The COMMIT is expected to be routed through TextProtocolBackendHandlerFactory with the SQL text "commit"
+ // (not through DatabaseCommunicationEngineFactory), so that factory is the one stubbed here.
+ try (MockedStatic<TextProtocolBackendHandlerFactory> mockedStatic = mockStatic(TextProtocolBackendHandlerFactory.class)) {
+ mockedStatic.when(() -> TextProtocolBackendHandlerFactory.newInstance(any(MySQLDatabaseType.class), eq("commit"), any(Supplier.class), eq(connectionSession)))
+ .thenReturn(textProtocolBackendHandler);
+ actual = mysqlComStmtExecuteExecutor.execute().iterator();
+ }
assertThat(mysqlComStmtExecuteExecutor.getResponseType(), is(ResponseType.UPDATE));
+ // A successful COMMIT is expected to produce exactly one OK packet.
+ assertThat(actual.next(), instanceOf(MySQLOKPacket.class));
+ assertFalse(actual.hasNext());
}
}
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-reactive-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/reactive/mysql/command/query/binary/execute/ReactiveMySQLComStmtExecuteExecutor.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-reactive-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/reactive/mysql/command/query/binary/execute/ReactiveMySQLComStmtExecuteExecutor.java
index caaf74630f0..dc56c0e77a4 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-reactive-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/reactive/mysql/command/query/binary/execute/ReactiveMySQLComStmtExecuteExecutor.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-reactive-mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/reactive/mysql/command/query/binary/execute/ReactiveMySQLComStmtExecuteExecutor.java
@@ -17,14 +17,18 @@
package org.apache.shardingsphere.proxy.frontend.reactive.mysql.command.query.binary.execute;
-import com.google.common.base.Preconditions;
import io.vertx.core.Future;
import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import lombok.SneakyThrows;
import org.apache.shardingsphere.db.protocol.binary.BinaryCell;
import org.apache.shardingsphere.db.protocol.binary.BinaryRow;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLBinaryColumnType;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
+import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLNewParametersBoundFlag;
import org.apache.shardingsphere.db.protocol.mysql.packet.MySQLPacket;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatement;
+import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementRegistry;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute.MySQLBinaryResultSetRowPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.execute.MySQLComStmtExecutePacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLEofPacket;
@@ -32,12 +36,10 @@ import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
import org.apache.shardingsphere.infra.binder.SQLStatementContextFactory;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.type.TableAvailable;
-import org.apache.shardingsphere.infra.database.type.DatabaseTypeEngine;
import org.apache.shardingsphere.infra.database.type.DatabaseTypeFactory;
import org.apache.shardingsphere.infra.executor.check.SQLCheckEngine;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
-import org.apache.shardingsphere.parser.rule.SQLParserRule;
import org.apache.shardingsphere.proxy.backend.communication.DatabaseCommunicationEngineFactory;
import org.apache.shardingsphere.proxy.backend.communication.SQLStatementDatabaseHolder;
import org.apache.shardingsphere.proxy.backend.communication.vertx.VertxDatabaseCommunicationEngine;
@@ -67,57 +69,48 @@ import java.util.Optional;
/**
* Reactive COM_STMT_EXECUTE command executor for MySQL.
*/
+@RequiredArgsConstructor
public final class ReactiveMySQLComStmtExecuteExecutor implements ReactiveCommandExecutor {
- private final VertxDatabaseCommunicationEngine databaseCommunicationEngine;
+ private final MySQLComStmtExecutePacket packet;
- private final TextProtocolBackendHandler textProtocolBackendHandler;
+ private final ConnectionSession connectionSession;
- private final int characterSet;
+ private VertxDatabaseCommunicationEngine databaseCommunicationEngine;
+
+ private TextProtocolBackendHandler textProtocolBackendHandler;
@Getter
- private volatile ResponseType responseType;
+ private ResponseType responseType;
private int currentSequenceId;
- public ReactiveMySQLComStmtExecuteExecutor(final MySQLComStmtExecutePacket packet, final ConnectionSession connectionSession) throws SQLException {
+ @SneakyThrows(SQLException.class)
+ @Override
+ public Future<Collection<DatabasePacket<?>>> executeFuture() {
+ MySQLPreparedStatement preparedStatement = updateAndGetPreparedStatement();
String databaseName = connectionSession.getDatabaseName();
MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
- Optional<SQLParserRule> sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().findSingleRule(SQLParserRule.class);
- Preconditions.checkState(sqlParserRule.isPresent());
- SQLStatement sqlStatement = sqlParserRule.get().getSQLParserEngine(
- DatabaseTypeEngine.getTrunkDatabaseTypeName(metaDataContexts.getMetaData().getDatabases().get(databaseName).getProtocolType())).parse(packet.getSql(), true);
- SQLStatementContext<?> sqlStatementContext = SQLStatementContextFactory.newInstance(metaDataContexts.getMetaData().getDatabases(), packet.getParameters(),
+ SQLStatement sqlStatement = preparedStatement.getSqlStatement();
+ List<Object> parameters = packet.readParameters(preparedStatement.getParameterTypes());
+ SQLStatementContext<?> sqlStatementContext = SQLStatementContextFactory.newInstance(metaDataContexts.getMetaData().getDatabases(), parameters,
sqlStatement, connectionSession.getDefaultDatabaseName());
// TODO optimize SQLStatementDatabaseHolder
if (sqlStatementContext instanceof TableAvailable) {
((TableAvailable) sqlStatementContext).getTablesContext().getDatabaseName().ifPresent(SQLStatementDatabaseHolder::set);
}
SQLCheckEngine.check(sqlStatement, Collections.emptyList(), getRules(databaseName), databaseName, metaDataContexts.getMetaData().getDatabases(), connectionSession.getGrantee());
- characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
+ int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
// TODO Refactor the following branch
if (sqlStatement instanceof TCLStatement) {
- databaseCommunicationEngine = null;
textProtocolBackendHandler =
- TextProtocolBackendHandlerFactory.newInstance(DatabaseTypeFactory.getInstance("MySQL"), packet.getSql(), () -> Optional.of(sqlStatement), connectionSession);
- return;
+ TextProtocolBackendHandlerFactory.newInstance(DatabaseTypeFactory.getInstance("MySQL"), preparedStatement.getSql(), () -> Optional.of(sqlStatement), connectionSession);
+ } else {
+ databaseCommunicationEngine = DatabaseCommunicationEngineFactory.getInstance()
+ .newBinaryProtocolInstance(sqlStatementContext, preparedStatement.getSql(), parameters, connectionSession.getBackendConnection());
}
- textProtocolBackendHandler = null;
- databaseCommunicationEngine = DatabaseCommunicationEngineFactory.getInstance()
- .newBinaryProtocolInstance(sqlStatementContext, packet.getSql(), packet.getParameters(), connectionSession.getBackendConnection());
- }
-
- private static Collection<ShardingSphereRule> getRules(final String databaseName) {
- Collection<ShardingSphereRule> result;
- result = new LinkedList<>(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabases().get(databaseName).getRuleMetaData().getRules());
- result.addAll(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getRules());
- return result;
- }
-
- @Override
- public Future<Collection<DatabasePacket<?>>> executeFuture() {
return (null != databaseCommunicationEngine ? databaseCommunicationEngine.execute() : textProtocolBackendHandler.executeFuture()).compose(responseHeader -> {
- Collection<DatabasePacket<?>> headerPackets = responseHeader instanceof QueryResponseHeader ? processQuery((QueryResponseHeader) responseHeader)
+ Collection<DatabasePacket<?>> headerPackets = responseHeader instanceof QueryResponseHeader ? processQuery((QueryResponseHeader) responseHeader, characterSet)
: processUpdate((UpdateResponseHeader) responseHeader);
List<DatabasePacket<?>> result = new LinkedList<>(headerPackets);
if (ResponseType.UPDATE == responseType) {
@@ -135,7 +128,22 @@ public final class ReactiveMySQLComStmtExecuteExecutor implements ReactiveComman
});
}
- private Collection<DatabasePacket<?>> processQuery(final QueryResponseHeader queryResponseHeader) {
+ /**
+ * Look up the prepared statement referenced by the COM_STMT_EXECUTE packet and refresh its cached parameter types if the packet carries new ones.
+ *
+ * @return prepared statement registered for this connection under the packet's statement id
+ * @throws IllegalStateException if no prepared statement is registered under that id
+ */
+ private MySQLPreparedStatement updateAndGetPreparedStatement() {
+ MySQLPreparedStatement result = MySQLPreparedStatementRegistry.getInstance().getConnectionPreparedStatements(connectionSession.getConnectionId()).get(packet.getStatementId());
+ if (null == result) {
+ // Fail fast with a descriptive message instead of an anonymous NPE further down when the client sends an unknown or already closed statement id.
+ throw new IllegalStateException(String.format("Unknown prepared statement id `%d` for connection `%d`", packet.getStatementId(), connectionSession.getConnectionId()));
+ }
+ // Parameter types are only replaced when the packet flags them as present; otherwise the types cached on the statement are reused.
+ if (MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST == packet.getNewParametersBoundFlag()) {
+ result.setParameterTypes(packet.getNewParameterTypes());
+ }
+ return result;
+ }
+
+ /**
+ * Collect the rules used for the SQL check: the given database's rules first, followed by the global rules.
+ *
+ * @param databaseName name of the database whose rules are included
+ * @return mutable collection of database rules followed by global rules
+ */
+ private static Collection<ShardingSphereRule> getRules(final String databaseName) {
+ MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
+ Collection<ShardingSphereRule> result = new LinkedList<>(metaDataContexts.getMetaData().getDatabases().get(databaseName).getRuleMetaData().getRules());
+ result.addAll(metaDataContexts.getMetaData().getGlobalRuleMetaData().getRules());
+ return result;
+ }
+
+ private Collection<DatabasePacket<?>> processQuery(final QueryResponseHeader queryResponseHeader, final int characterSet) {
responseType = ResponseType.QUERY;
Collection<DatabasePacket<?>> result = ResponsePacketBuilder.buildQueryResponsePackets(queryResponseHeader, characterSet);
currentSequenceId = result.size();