You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2021/01/28 11:21:51 UTC
[shardingsphere] branch master updated: Simplify ResultSetAdapterTest (#9206)
This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 1938c79 Simplify ResultSetAdapterTest (#9206)
1938c79 is described below
commit 1938c799a1e7b38971f20acc6297e714a91ef940
Author: Liang Zhang <te...@163.com>
AuthorDate: Thu Jan 28 19:21:18 2021 +0800
Simplify ResultSetAdapterTest (#9206)
* Simplify ResultSetAdapterTest
* Simplify StatementAdapterTest
---
.../jdbc/adapter/AbstractStatementAdapter.java | 11 +-
.../driver/jdbc/adapter/ResultSetAdapterTest.java | 174 ++++-------
.../driver/jdbc/adapter/StatementAdapterTest.java | 334 ++++++++-------------
3 files changed, 179 insertions(+), 340 deletions(-)
diff --git a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/main/java/org/apache/shardingsphere/driver/jdbc/adapter/AbstractStatementAdapter.java b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/main/java/org/apache/shardingsphere/driver/jdbc/adapter/AbstractStatementAdapter.java
index b0f9b83..c7d4aca 100644
--- a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/main/java/org/apache/shardingsphere/driver/jdbc/adapter/AbstractStatementAdapter.java
+++ b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/main/java/org/apache/shardingsphere/driver/jdbc/adapter/AbstractStatementAdapter.java
@@ -124,13 +124,12 @@ public abstract class AbstractStatementAdapter extends AbstractUnsupportedOperat
public final int getUpdateCount() throws SQLException {
if (isAccumulate()) {
return accumulate();
- } else {
- Collection<? extends Statement> statements = getRoutedStatements();
- if (statements.isEmpty()) {
- return -1;
- }
- return getRoutedStatements().iterator().next().getUpdateCount();
}
+ Collection<? extends Statement> statements = getRoutedStatements();
+ if (statements.isEmpty()) {
+ return -1;
+ }
+ return getRoutedStatements().iterator().next().getUpdateCount();
}
private int accumulate() throws SQLException {
diff --git a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/ResultSetAdapterTest.java b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/ResultSetAdapterTest.java
index 8d729a1..b808320 100644
--- a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/ResultSetAdapterTest.java
+++ b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/ResultSetAdapterTest.java
@@ -17,172 +17,104 @@
package org.apache.shardingsphere.driver.jdbc.adapter;
-import org.apache.shardingsphere.driver.jdbc.base.AbstractShardingSphereDataSourceForShardingTest;
-import org.apache.shardingsphere.infra.database.type.DatabaseTypeRegistry;
-import org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection;
-import org.apache.shardingsphere.driver.jdbc.util.JDBCTestSQL;
-import org.apache.shardingsphere.infra.database.type.DatabaseType;
-import org.junit.After;
-import org.junit.Before;
+import org.apache.shardingsphere.driver.jdbc.core.resultset.ShardingSphereResultSet;
+import org.apache.shardingsphere.driver.jdbc.core.statement.ShardingSphereStatement;
+import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
+import org.apache.shardingsphere.infra.merge.result.MergedResult;
import org.junit.Test;
import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
import java.sql.SQLException;
-import java.sql.Statement;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
+import java.util.Collections;
import static org.hamcrest.CoreMatchers.is;
-import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
-public final class ResultSetAdapterTest extends AbstractShardingSphereDataSourceForShardingTest {
-
- private final List<ShardingSphereConnection> shardingSphereConnections = new ArrayList<>();
-
- private final List<Statement> statements = new ArrayList<>();
-
- private final Map<DatabaseType, ResultSet> resultSets = new HashMap<>();
-
- @Before
- public void init() throws SQLException {
- ShardingSphereConnection connection = getShardingSphereDataSource().getConnection();
- shardingSphereConnections.add(connection);
- Statement statement = connection.createStatement();
- statements.add(statement);
- resultSets.put(DatabaseTypeRegistry.getActualDatabaseType("H2"), statement.executeQuery(JDBCTestSQL.SELECT_GROUP_BY_USER_ID_SQL));
- }
-
- @After
- public void close() throws SQLException {
- for (ResultSet each : resultSets.values()) {
- each.close();
- }
- for (Statement each : statements) {
- each.close();
- }
- for (ShardingSphereConnection each : shardingSphereConnections) {
- each.close();
- }
- }
+public final class ResultSetAdapterTest {
@Test
public void assertClose() throws SQLException {
- for (Entry<DatabaseType, ResultSet> each : resultSets.entrySet()) {
- each.getValue().close();
- assertClose((AbstractResultSetAdapter) each.getValue(), each.getKey());
- }
- }
-
- private void assertClose(final AbstractResultSetAdapter actual, final DatabaseType type) throws SQLException {
+ ResultSet resultSet = mock(ResultSet.class);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ actual.close();
assertTrue(actual.isClosed());
- assertThat(actual.getResultSets().size(), is(4));
- if (DatabaseTypeRegistry.getActualDatabaseType("Oracle") != type) {
- for (ResultSet each : actual.getResultSets()) {
- assertTrue(each.isClosed());
- }
- }
- }
-
- @Test
- public void assertWasNull() throws SQLException {
- for (ResultSet each : resultSets.values()) {
- assertFalse(each.wasNull());
- }
+ verify(resultSet).close();
}
@Test
public void assertSetFetchDirection() throws SQLException {
- for (Entry<DatabaseType, ResultSet> each : resultSets.entrySet()) {
- assertThat(each.getValue().getFetchDirection(), is(ResultSet.FETCH_FORWARD));
- try {
- each.getValue().setFetchDirection(ResultSet.FETCH_REVERSE);
- } catch (final SQLException ignored) {
- }
- if (each.getKey() == DatabaseTypeRegistry.getActualDatabaseType("MySQL") || each.getKey() == DatabaseTypeRegistry.getActualDatabaseType("PostgreSQL")) {
- assertFetchDirection((AbstractResultSetAdapter) each.getValue(), ResultSet.FETCH_REVERSE, each.getKey());
- }
- }
- }
-
- private void assertFetchDirection(final AbstractResultSetAdapter actual, final int fetchDirection, final DatabaseType type) throws SQLException {
- // H2 do not implement getFetchDirection
- assertThat(actual.getFetchDirection(), is(
- DatabaseTypeRegistry.getActualDatabaseType("H2") == type || DatabaseTypeRegistry.getActualDatabaseType("PostgreSQL") == type ? ResultSet.FETCH_FORWARD : fetchDirection));
- assertThat(actual.getResultSets().size(), is(4));
- for (ResultSet each : actual.getResultSets()) {
- assertThat(each.getFetchDirection(), is(
- DatabaseTypeRegistry.getActualDatabaseType("H2") == type || DatabaseTypeRegistry.getActualDatabaseType("PostgreSQL") == type ? ResultSet.FETCH_FORWARD : fetchDirection));
- }
+ ResultSet resultSet = mock(ResultSet.class);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ actual.setFetchDirection(ResultSet.FETCH_REVERSE);
+ verify(resultSet).setFetchDirection(ResultSet.FETCH_REVERSE);
}
@Test
public void assertSetFetchSize() throws SQLException {
- for (Entry<DatabaseType, ResultSet> each : resultSets.entrySet()) {
- if (DatabaseTypeRegistry.getActualDatabaseType("MySQL") == each.getKey() || DatabaseTypeRegistry.getActualDatabaseType("PostgreSQL") == each.getKey()) {
- assertThat(each.getValue().getFetchSize(), is(0));
- }
- each.getValue().setFetchSize(100);
- assertFetchSize((AbstractResultSetAdapter) each.getValue(), each.getKey());
- }
- }
-
- private void assertFetchSize(final AbstractResultSetAdapter actual, final DatabaseType type) throws SQLException {
- // H2 do not implement getFetchSize
- assertThat(actual.getFetchSize(), is(DatabaseTypeRegistry.getActualDatabaseType("H2") == type ? 0 : 100));
- assertThat(actual.getResultSets().size(), is(4));
- for (ResultSet each : actual.getResultSets()) {
- assertThat(each.getFetchSize(), is(DatabaseTypeRegistry.getActualDatabaseType("H2") == type ? 0 : 100));
- }
+ ResultSet resultSet = mock(ResultSet.class);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ actual.setFetchSize(100);
+ verify(resultSet).setFetchSize(100);
}
@Test
public void assertGetType() throws SQLException {
- for (ResultSet each : resultSets.values()) {
- assertThat(each.getType(), is(ResultSet.TYPE_FORWARD_ONLY));
- }
+ ResultSet resultSet = mock(ResultSet.class);
+ when(resultSet.getType()).thenReturn(ResultSet.TYPE_FORWARD_ONLY);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ assertThat(actual.getType(), is(ResultSet.TYPE_FORWARD_ONLY));
}
@Test
public void assertGetConcurrency() throws SQLException {
- for (ResultSet each : resultSets.values()) {
- assertThat(each.getConcurrency(), is(ResultSet.CONCUR_READ_ONLY));
- }
+ ResultSet resultSet = mock(ResultSet.class);
+ when(resultSet.getConcurrency()).thenReturn(ResultSet.CONCUR_READ_ONLY);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ assertThat(actual.getConcurrency(), is(ResultSet.CONCUR_READ_ONLY));
}
@Test
public void assertGetStatement() throws SQLException {
- for (ResultSet each : resultSets.values()) {
- assertNotNull(each.getStatement());
- }
+ ResultSet resultSet = mock(ResultSet.class);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ assertNotNull(actual.getStatement());
}
@Test
public void assertClearWarnings() throws SQLException {
- for (ResultSet each : resultSets.values()) {
- assertNull(each.getWarnings());
- each.clearWarnings();
- assertNull(each.getWarnings());
- }
+ ResultSet resultSet = mock(ResultSet.class);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ actual.clearWarnings();
+ verify(resultSet).clearWarnings();
}
@Test
public void assertGetMetaData() throws SQLException {
- for (ResultSet each : resultSets.values()) {
- assertNotNull(each.getMetaData());
- }
+ ResultSet resultSet = mock(ResultSet.class);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ assertThat(actual.getMetaData().getColumnLabel(1), is("col"));
}
@Test
public void assertFindColumn() throws SQLException {
- for (Entry<DatabaseType, ResultSet> each : resultSets.entrySet()) {
- assertThat(each.getValue().findColumn("user_id"), is(1));
- }
+ ResultSet resultSet = mock(ResultSet.class);
+ when(resultSet.findColumn("col")).thenReturn(1);
+ ShardingSphereResultSet actual = mockShardingSphereResultSet(resultSet);
+ assertThat(actual.findColumn("col"), is(1));
+ }
+
+ private ShardingSphereResultSet mockShardingSphereResultSet(final ResultSet resultSet) throws SQLException {
+ ResultSetMetaData resultSetMetaData = mock(ResultSetMetaData.class);
+ when(resultSetMetaData.getColumnLabel(1)).thenReturn("col");
+ when(resultSetMetaData.getColumnCount()).thenReturn(1);
+ when(resultSet.getMetaData()).thenReturn(resultSetMetaData);
+ return new ShardingSphereResultSet(Collections.singletonList(resultSet), mock(MergedResult.class), mock(ShardingSphereStatement.class, RETURNS_DEEP_STUBS), mock(ExecutionContext.class));
}
}
diff --git a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/StatementAdapterTest.java b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/StatementAdapterTest.java
index 66c3a87..4de7bdb 100644
--- a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/StatementAdapterTest.java
+++ b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/StatementAdapterTest.java
@@ -17,320 +17,228 @@
package org.apache.shardingsphere.driver.jdbc.adapter;
-import com.google.common.collect.Lists;
-import org.apache.shardingsphere.driver.jdbc.base.AbstractShardingSphereDataSourceForShardingTest;
+import lombok.SneakyThrows;
import org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection;
-import org.apache.shardingsphere.driver.jdbc.core.statement.ShardingSpherePreparedStatement;
import org.apache.shardingsphere.driver.jdbc.core.statement.ShardingSphereStatement;
-import org.apache.shardingsphere.driver.jdbc.util.JDBCTestSQL;
-import org.apache.shardingsphere.infra.database.type.DatabaseType;
-import org.apache.shardingsphere.infra.database.type.DatabaseTypeRegistry;
-import org.junit.After;
-import org.junit.Before;
+import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
+import org.apache.shardingsphere.infra.rule.type.DataNodeContainedRule;
import org.junit.Test;
+import java.lang.reflect.Field;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
+import java.util.Arrays;
+import java.util.Collections;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.doReturn;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
-public final class StatementAdapterTest extends AbstractShardingSphereDataSourceForShardingTest {
-
- private final List<ShardingSphereConnection> shardingSphereConnections = new ArrayList<>();
-
- private final Map<DatabaseType, Statement> statements = new HashMap<>();
-
- private final String sql = JDBCTestSQL.SELECT_GROUP_BY_USER_ID_SQL;
-
- @Before
- public void init() {
- ShardingSphereConnection connection = getShardingSphereDataSource().getConnection();
- shardingSphereConnections.add(connection);
- statements.put(DatabaseTypeRegistry.getActualDatabaseType("H2"), connection.createStatement());
- }
-
- @After
- public void close() throws SQLException {
- for (Statement each : statements.values()) {
- each.close();
- }
- for (ShardingSphereConnection each : shardingSphereConnections) {
- each.close();
- }
- }
+public final class StatementAdapterTest {
@Test
public void assertClose() throws SQLException {
- for (Statement each : statements.values()) {
- each.executeQuery(sql);
- each.close();
- assertTrue(each.isClosed());
- assertTrue(((ShardingSphereStatement) each).getRoutedStatements().isEmpty());
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.close();
+ assertTrue(actual.isClosed());
+ assertTrue(actual.getRoutedStatements().isEmpty());
+ verify(statement).close();
}
@Test
public void assertSetPoolable() throws SQLException {
- for (Entry<DatabaseType, Statement> each : statements.entrySet()) {
- each.getValue().setPoolable(true);
- each.getValue().executeQuery(sql);
- assertPoolable((ShardingSphereStatement) each.getValue(), true);
- each.getValue().setPoolable(false);
- assertPoolable((ShardingSphereStatement) each.getValue(), false);
- }
- }
-
- private void assertPoolable(final ShardingSphereStatement actual, final boolean poolable) throws SQLException {
- assertThat(actual.isPoolable(), is(poolable));
- assertThat(actual.getRoutedStatements().size(), is(4));
- for (Statement each : actual.getRoutedStatements()) {
- // H2 do not implements method `setPoolable()`
- assertFalse(each.isPoolable());
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.setPoolable(true);
+ assertTrue(actual.isPoolable());
+ verify(statement).setPoolable(true);
}
@Test
public void assertSetFetchSize() throws SQLException {
- for (Statement each : statements.values()) {
- each.setFetchSize(4);
- each.executeQuery(sql);
- assertFetchSize((ShardingSphereStatement) each, 4);
- each.setFetchSize(100);
- assertFetchSize((ShardingSphereStatement) each, 100);
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.setFetchSize(100);
+ assertThat(actual.getFetchSize(), is(100));
+ verify(statement).setFetchSize(100);
}
- private void assertFetchSize(final ShardingSphereStatement actual, final int fetchSize) throws SQLException {
- assertThat(actual.getFetchSize(), is(fetchSize));
- assertThat(actual.getRoutedStatements().size(), is(4));
- for (Statement each : actual.getRoutedStatements()) {
- assertThat(each.getFetchSize(), is(fetchSize));
- }
- }
-
@Test
public void assertSetFetchDirection() throws SQLException {
- for (Statement each : statements.values()) {
- each.setFetchDirection(ResultSet.FETCH_FORWARD);
- each.executeQuery(sql);
- assertFetchDirection((ShardingSphereStatement) each, ResultSet.FETCH_FORWARD);
- each.setFetchDirection(ResultSet.FETCH_REVERSE);
- assertFetchDirection((ShardingSphereStatement) each, ResultSet.FETCH_REVERSE);
- }
- }
-
- private void assertFetchDirection(final ShardingSphereStatement actual, final int fetchDirection) throws SQLException {
- assertThat(actual.getFetchDirection(), is(fetchDirection));
- for (Statement each : actual.getRoutedStatements()) {
- // H2,MySQL getFetchDirection() always return ResultSet.FETCH_FORWARD
- assertThat(each.getFetchDirection(), is(ResultSet.FETCH_FORWARD));
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.setFetchDirection(ResultSet.FETCH_REVERSE);
+ assertThat(actual.getFetchDirection(), is(ResultSet.FETCH_REVERSE));
+ verify(statement).setFetchDirection(ResultSet.FETCH_REVERSE);
}
-
+
@Test
public void assertSetEscapeProcessing() throws SQLException {
- for (Statement each : statements.values()) {
- each.setEscapeProcessing(true);
- each.executeQuery(sql);
- each.setEscapeProcessing(false);
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.setEscapeProcessing(true);
+ verify(statement).setEscapeProcessing(true);
}
@Test
public void assertCancel() throws SQLException {
- for (Statement each : statements.values()) {
- each.executeQuery(sql);
- each.cancel();
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.cancel();
+ verify(statement).cancel();
}
@Test
- public void assertGetUpdateCount() throws SQLException {
- String sql = "DELETE FROM t_order WHERE status = 'init'";
- for (Entry<DatabaseType, Statement> each : statements.entrySet()) {
- each.getValue().execute(sql);
- assertThat(each.getValue().getUpdateCount(), is(4));
- }
+ public void assertGetUpdateCountWithoutAccumulate() throws SQLException {
+ Statement statement1 = mock(Statement.class);
+ when(statement1.getUpdateCount()).thenReturn(Integer.MAX_VALUE);
+ Statement statement2 = mock(Statement.class);
+ when(statement2.getUpdateCount()).thenReturn(Integer.MAX_VALUE);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement1, statement2);
+ assertThat(actual.getUpdateCount(), is(Integer.MAX_VALUE));
}
@Test
- public void assertGetUpdateCountNoData() throws SQLException {
- String sql = "DELETE FROM t_order WHERE status = 'none'";
- for (Entry<DatabaseType, Statement> each : statements.entrySet()) {
- each.getValue().execute(sql);
- assertThat(each.getValue().getUpdateCount(), is(0));
- }
+ public void assertGetUpdateCountWithoutAccumulateAndInvalidResult() throws SQLException {
+ Statement statement = mock(Statement.class);
+ when(statement.getUpdateCount()).thenReturn(-1);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ assertThat(actual.getUpdateCount(), is(-1));
}
@Test
- public void assertGetUpdateCountSelect() throws SQLException {
- for (Statement each : statements.values()) {
- each.execute(sql);
- assertThat(each.getUpdateCount(), is(-1));
- }
+ public void assertGetUpdateCountWithoutAccumulateAndEmptyResult() throws SQLException {
+ ShardingSphereStatement actual = mockShardingSphereStatement();
+ assertThat(actual.getUpdateCount(), is(-1));
}
@Test
- public void assertOverMaxUpdateRow() throws SQLException {
+ public void assertGetUpdateCountWithAccumulate() throws SQLException {
Statement statement1 = mock(Statement.class);
when(statement1.getUpdateCount()).thenReturn(Integer.MAX_VALUE);
Statement statement2 = mock(Statement.class);
when(statement2.getUpdateCount()).thenReturn(Integer.MAX_VALUE);
- ShardingSphereStatement shardingSphereStatement1 = spy(new ShardingSphereStatement(getShardingSphereDataSource().getConnection()));
- doReturn(true).when(shardingSphereStatement1).isAccumulate();
- doReturn(Lists.newArrayList(statement1, statement2)).when(shardingSphereStatement1).getRoutedStatements();
- assertThat(shardingSphereStatement1.getUpdateCount(), is(Integer.MAX_VALUE));
- ShardingSpherePreparedStatement shardingSphereStatement2 = spy(new ShardingSpherePreparedStatement(getShardingSphereDataSource().getConnection(), sql));
- doReturn(true).when(shardingSphereStatement2).isAccumulate();
- doReturn(Lists.newArrayList(statement1, statement2)).when(shardingSphereStatement2).getRoutedStatements();
- assertThat(shardingSphereStatement2.getUpdateCount(), is(Integer.MAX_VALUE));
- }
-
- @Test
- public void assertNotAccumulateUpdateRow() throws SQLException {
- Statement statement1 = mock(Statement.class);
- when(statement1.getUpdateCount()).thenReturn(10);
- Statement statement2 = mock(Statement.class);
- when(statement2.getUpdateCount()).thenReturn(10);
- ShardingSphereStatement shardingSphereStatement1 = spy(new ShardingSphereStatement(getShardingSphereDataSource().getConnection()));
- doReturn(false).when(shardingSphereStatement1).isAccumulate();
- doReturn(Lists.newArrayList(statement1, statement2)).when(shardingSphereStatement1).getRoutedStatements();
- assertThat(shardingSphereStatement1.getUpdateCount(), is(10));
- ShardingSpherePreparedStatement shardingSphereStatement2 = spy(new ShardingSpherePreparedStatement(getShardingSphereDataSource().getConnection(), sql));
- doReturn(false).when(shardingSphereStatement2).isAccumulate();
- doReturn(Lists.newArrayList(statement1, statement2)).when(shardingSphereStatement2).getRoutedStatements();
- assertThat(shardingSphereStatement2.getUpdateCount(), is(10));
+ ShardingSphereStatement actual = mockShardingSphereStatementWithNeedAccumulate(statement1, statement2);
+ assertThat(actual.getUpdateCount(), is(Integer.MAX_VALUE));
}
@Test
- public void assertGetWarnings() throws SQLException {
- for (Statement each : statements.values()) {
- assertNull(each.getWarnings());
- }
+ public void assertGetWarnings() {
+ assertNull(mockShardingSphereStatement().getWarnings());
}
@Test
- public void assertClearWarnings() throws SQLException {
- for (Statement each : statements.values()) {
- each.clearWarnings();
- }
+ public void assertClearWarnings() {
+ mockShardingSphereStatement().clearWarnings();
}
@Test
public void assertGetMoreResults() throws SQLException {
- for (Statement each : statements.values()) {
- assertFalse(each.getMoreResults());
- }
+ Statement statement = mock(Statement.class);
+ when(statement.getMoreResults()).thenReturn(true);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ assertTrue(actual.getMoreResults());
}
@Test
- public void assertGetMoreResultsWithCurrent() throws SQLException {
- for (Statement each : statements.values()) {
- assertFalse(each.getMoreResults(Statement.KEEP_CURRENT_RESULT));
- }
+ public void assertGetMoreResultsWithCurrent() {
+ assertFalse(mockShardingSphereStatement().getMoreResults(Statement.KEEP_CURRENT_RESULT));
}
@Test
public void assertGetMaxFieldSizeWithoutRoutedStatements() throws SQLException {
- for (Statement each : statements.values()) {
- assertThat(each.getMaxFieldSize(), is(0));
- }
+ assertThat(mockShardingSphereStatement().getMaxFieldSize(), is(0));
}
@Test
public void assertGetMaxFieldSizeWithRoutedStatements() throws SQLException {
- for (Statement each : statements.values()) {
- each.executeQuery(sql);
- assertTrue(each.getMaxFieldSize() > -1);
- }
+ Statement statement = mock(Statement.class);
+ when(statement.getMaxFieldSize()).thenReturn(10);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ assertThat(actual.getMaxFieldSize(), is(10));
}
@Test
public void assertSetMaxFieldSize() throws SQLException {
- for (Entry<DatabaseType, Statement> each : statements.entrySet()) {
- each.getValue().executeQuery(sql);
- each.getValue().setMaxFieldSize(10);
- assertThat(each.getValue().getMaxFieldSize(), is(DatabaseTypeRegistry.getActualDatabaseType("H2") == each.getKey() ? 0 : 10));
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.setMaxFieldSize(10);
+ verify(statement).setMaxFieldSize(10);
}
@Test
public void assertGetMaxRowsWitRoutedStatements() throws SQLException {
- for (Statement each : statements.values()) {
- assertThat(each.getMaxRows(), is(-1));
- }
+ assertThat(mockShardingSphereStatement().getMaxRows(), is(-1));
}
@Test
public void assertGetMaxRowsWithoutRoutedStatements() throws SQLException {
- for (Statement each : statements.values()) {
- each.executeQuery(sql);
- assertThat(each.getMaxRows(), is(0));
- }
+ Statement statement = mock(Statement.class);
+ when(statement.getMaxRows()).thenReturn(10);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ assertThat(actual.getMaxRows(), is(10));
}
@Test
public void assertSetMaxRows() throws SQLException {
- for (Statement each : statements.values()) {
- each.executeQuery(sql);
- each.setMaxRows(10);
- assertThat(each.getMaxRows(), is(10));
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.setMaxRows(10);
+ verify(statement).setMaxRows(10);
}
@Test
public void assertGetQueryTimeoutWithoutRoutedStatements() throws SQLException {
- for (Statement each : statements.values()) {
- assertThat(each.getQueryTimeout(), is(0));
- }
+ assertThat(mockShardingSphereStatement().getQueryTimeout(), is(0));
}
-
+
@Test
public void assertGetQueryTimeoutWithRoutedStatements() throws SQLException {
- for (Statement each : statements.values()) {
- each.executeQuery(sql);
- assertThat(each.getQueryTimeout(), is(0));
- }
+ Statement statement = mock(Statement.class);
+ when(statement.getQueryTimeout()).thenReturn(10);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ assertThat(actual.getQueryTimeout(), is(10));
}
@Test
public void assertSetQueryTimeout() throws SQLException {
- for (Entry<DatabaseType, Statement> each : statements.entrySet()) {
- each.getValue().executeQuery(sql);
- each.getValue().setQueryTimeout(10);
- assertThat(each.getValue().getQueryTimeout(), is(10));
- }
- }
-
- @Test
- public void assertGetGeneratedKeysForSingleRoutedStatement() throws SQLException {
- for (Statement each : statements.values()) {
- each.execute("INSERT INTO t_order_item (user_id, order_id, status) VALUES (1, 1, 'init')", Statement.RETURN_GENERATED_KEYS);
- ResultSet generatedKeysResult = each.getGeneratedKeys();
- assertTrue(generatedKeysResult.next());
- assertTrue(generatedKeysResult.getInt(1) > 0);
- }
- }
-
- @Test
- public void assertGetGeneratedKeysForMultipleRoutedStatement() throws SQLException {
- for (Statement each : statements.values()) {
- each.executeQuery("SELECT user_id AS usr_id FROM t_order WHERE order_id IN (1, 2)");
- assertFalse(each.getGeneratedKeys().next());
- }
+ Statement statement = mock(Statement.class);
+ ShardingSphereStatement actual = mockShardingSphereStatement(statement);
+ actual.setQueryTimeout(10);
+ verify(statement).setQueryTimeout(10);
+ }
+
+ private ShardingSphereStatement mockShardingSphereStatement(final Statement... statements) {
+ ShardingSphereStatement result = new ShardingSphereStatement(mock(ShardingSphereConnection.class, RETURNS_DEEP_STUBS));
+ result.getRoutedStatements().addAll(Arrays.asList(statements));
+ return result;
+ }
+
+ private ShardingSphereStatement mockShardingSphereStatementWithNeedAccumulate(final Statement... statements) {
+ ShardingSphereConnection connection = mock(ShardingSphereConnection.class, RETURNS_DEEP_STUBS);
+ DataNodeContainedRule rule = mock(DataNodeContainedRule.class);
+ when(rule.isNeedAccumulate(any())).thenReturn(true);
+ when(connection.getMetaDataContexts().getDefaultMetaData().getRuleMetaData().getRules()).thenReturn(Collections.singletonList(rule));
+ ShardingSphereStatement result = new ShardingSphereStatement(connection);
+ result.getRoutedStatements().addAll(Arrays.asList(statements));
+ ExecutionContext executionContext = mock(ExecutionContext.class, RETURNS_DEEP_STUBS);
+ setExecutionContext(result, executionContext);
+ return result;
+ }
+
+ @SneakyThrows(ReflectiveOperationException.class)
+ private void setExecutionContext(final ShardingSphereStatement statement, final ExecutionContext executionContext) {
+ Field field = statement.getClass().getDeclaredField("executionContext");
+ field.setAccessible(true);
+ field.set(statement, executionContext);
}
}