You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this version.
Posted to notifications@shardingsphere.apache.org by du...@apache.org on 2023/06/25 02:14:06 UTC
[shardingsphere] branch master updated: Remove hint value from CommonSQLStatementContext. (#26485)
This is an automated email from the ASF dual-hosted git repository.
duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 861d6eaf485 Remove hint value from CommonSQLStatementContext. (#26485)
861d6eaf485 is described below
commit 861d6eaf4851b2c898d6bbea6c0c286bf1c3675e
Author: Chuxin Chen <ch...@qq.com>
AuthorDate: Sun Jun 25 10:13:59 2023 +0800
Remove hint value from CommonSQLStatementContext. (#26485)
---
.../query/text/query/MySQLComQueryPacket.java | 4 +--
.../route/ReadwriteSplittingDataSourceRouter.java | 10 ++++---
.../route/ReadwriteSplittingSQLRouter.java | 9 +++---
...ualifiedReadwriteSplittingDataSourceRouter.java | 5 +++-
...dReadwriteSplittingPrimaryDataSourceRouter.java | 14 +++++-----
...riteSplittingTransactionalDataSourceRouter.java | 3 +-
.../route/ReadwriteSplittingSQLRouterTest.java | 6 ++--
...dwriteSplittingPrimaryDataSourceRouterTest.java | 16 +++++++----
...SplittingTransactionalDataSourceRouterTest.java | 9 ++++--
.../sharding/auditor/ShardingSQLAuditor.java | 9 ++----
.../sharding/auditor/ShardingSQLAuditorTest.java | 12 +++++---
.../cache/route/CachedShardingSQLRouterTest.java | 16 +++++++----
.../statement/CommonSQLStatementContext.java | 32 ----------------------
.../infra/hint/HintValueContext.java | 21 ++++++++++++++
.../infra/hint/SQLHintExtractor.java | 15 ++--------
.../shardingsphere/infra/hint/SQLHintUtils.java | 9 +++---
.../infra/hint/SQLHintExtractorTest.java | 13 +++------
.../infra/hint/SQLHintUtilsTest.java | 8 ++++++
.../infra/connection/kernel/KernelProcessor.java | 6 ++--
.../infra/executor/audit/SQLAuditEngine.java | 6 ++--
.../infra/executor/audit/SQLAuditor.java | 5 +++-
.../infra/rewrite/SQLRewriteEntry.java | 21 ++++++++------
.../infra/rewrite/context/SQLRewriteContext.java | 10 +++----
.../infra/rewrite/SQLRewriteEntryTest.java | 7 +++--
.../rewrite/context/SQLRewriteContextTest.java | 22 +++++++++------
.../engine/GenericSQLRewriteEngineTest.java | 9 ++++--
.../rewrite/engine/RouteSQLRewriteEngineTest.java | 27 +++++++++---------
.../route/engine/impl/PartialSQLRouteExecutor.java | 13 ++++-----
.../engine/impl/PartialSQLRouteExecutorTest.java | 12 +++++---
.../infra/session/query/QueryContext.java | 6 +++-
.../statement/ShardingSpherePreparedStatement.java | 4 +--
.../core/statement/ShardingSphereStatement.java | 4 +--
.../shardingsphere/traffic/rule/TrafficRule.java | 9 ++----
.../traffic/rule/TrafficRuleTest.java | 4 +--
.../handler/ProxyBackendHandlerFactory.java | 6 ++--
.../text/query/MySQLMultiStatementsHandler.java | 4 +--
.../PostgreSQLBatchedStatementsExecutor.java | 4 +--
.../test/it/rewrite/engine/SQLRewriterIT.java | 3 +-
38 files changed, 208 insertions(+), 185 deletions(-)
diff --git a/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/text/query/MySQLComQueryPacket.java b/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/text/query/MySQLComQueryPacket.java
index 5d647ea752f..65fa98eccc6 100644
--- a/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/text/query/MySQLComQueryPacket.java
+++ b/db-protocol/mysql/src/main/java/org/apache/shardingsphere/db/protocol/mysql/packet/command/query/text/query/MySQLComQueryPacket.java
@@ -39,14 +39,14 @@ public final class MySQLComQueryPacket extends MySQLCommandPacket implements SQL
public MySQLComQueryPacket(final String sql, final boolean sqlCommentParseEnabled) {
super(MySQLCommandPacketType.COM_QUERY);
- hintValueContext = sqlCommentParseEnabled ? new HintValueContext() : SQLHintUtils.extractHint(sql);
+ hintValueContext = sqlCommentParseEnabled ? new HintValueContext() : SQLHintUtils.extractHint(sql).orElseGet(HintValueContext::new);
this.sql = sqlCommentParseEnabled ? sql : SQLHintUtils.removeHint(sql);
}
public MySQLComQueryPacket(final MySQLPacketPayload payload, final boolean sqlCommentParseEnabled) {
super(MySQLCommandPacketType.COM_QUERY);
String originSQL = payload.readStringEOF();
- hintValueContext = sqlCommentParseEnabled ? new HintValueContext() : SQLHintUtils.extractHint(originSQL);
+ hintValueContext = sqlCommentParseEnabled ? new HintValueContext() : SQLHintUtils.extractHint(originSQL).orElseGet(HintValueContext::new);
sql = sqlCommentParseEnabled ? originSQL : SQLHintUtils.removeHint(originSQL);
}
diff --git a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingDataSourceRouter.java b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingDataSourceRouter.java
index 1957f6cd6f3..42904e7f8a0 100644
--- a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingDataSourceRouter.java
+++ b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingDataSourceRouter.java
@@ -19,11 +19,12 @@ package org.apache.shardingsphere.readwritesplitting.route;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.readwritesplitting.route.qualified.QualifiedReadwriteSplittingDataSourceRouter;
import org.apache.shardingsphere.readwritesplitting.route.qualified.type.QualifiedReadwriteSplittingPrimaryDataSourceRouter;
-import org.apache.shardingsphere.readwritesplitting.route.standard.StandardReadwriteSplittingDataSourceRouter;
import org.apache.shardingsphere.readwritesplitting.route.qualified.type.QualifiedReadwriteSplittingTransactionalDataSourceRouter;
+import org.apache.shardingsphere.readwritesplitting.route.standard.StandardReadwriteSplittingDataSourceRouter;
import org.apache.shardingsphere.readwritesplitting.rule.ReadwriteSplittingDataSourceRule;
import java.util.Arrays;
@@ -41,13 +42,14 @@ public final class ReadwriteSplittingDataSourceRouter {
/**
* Route.
- *
+ *
* @param sqlStatementContext SQL statement context
+ * @param hintValueContext hint value context
* @return data source name
*/
- public String route(final SQLStatementContext sqlStatementContext) {
+ public String route(final SQLStatementContext sqlStatementContext, final HintValueContext hintValueContext) {
for (QualifiedReadwriteSplittingDataSourceRouter each : getQualifiedRouters(connectionContext)) {
- if (each.isQualified(sqlStatementContext, rule)) {
+ if (each.isQualified(sqlStatementContext, rule, hintValueContext)) {
return each.route(rule);
}
}
diff --git a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingSQLRouter.java b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingSQLRouter.java
index 1198ca7f7ad..99940859ab9 100644
--- a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingSQLRouter.java
+++ b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingSQLRouter.java
@@ -17,15 +17,15 @@
package org.apache.shardingsphere.readwritesplitting.route;
-import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.route.SQLRouter;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.readwritesplitting.constant.ReadwriteSplittingOrder;
import org.apache.shardingsphere.readwritesplitting.rule.ReadwriteSplittingDataSourceRule;
import org.apache.shardingsphere.readwritesplitting.rule.ReadwriteSplittingRule;
@@ -45,7 +45,7 @@ public final class ReadwriteSplittingSQLRouter implements SQLRouter<ReadwriteSpl
final ShardingSphereDatabase database, final ReadwriteSplittingRule rule, final ConfigurationProperties props, final ConnectionContext connectionContext) {
RouteContext result = new RouteContext();
ReadwriteSplittingDataSourceRule singleDataSourceRule = rule.getSingleDataSourceRule();
- String dataSourceName = new ReadwriteSplittingDataSourceRouter(singleDataSourceRule, connectionContext).route(queryContext.getSqlStatementContext());
+ String dataSourceName = new ReadwriteSplittingDataSourceRouter(singleDataSourceRule, connectionContext).route(queryContext.getSqlStatementContext(), queryContext.getHintValueContext());
result.getRouteUnits().add(new RouteUnit(new RouteMapper(singleDataSourceRule.getName(), dataSourceName), Collections.emptyList()));
return result;
}
@@ -60,7 +60,8 @@ public final class ReadwriteSplittingSQLRouter implements SQLRouter<ReadwriteSpl
Optional<ReadwriteSplittingDataSourceRule> dataSourceRule = rule.findDataSourceRule(dataSourceName);
if (dataSourceRule.isPresent() && dataSourceRule.get().getName().equalsIgnoreCase(each.getDataSourceMapper().getActualName())) {
toBeRemoved.add(each);
- String actualDataSourceName = new ReadwriteSplittingDataSourceRouter(dataSourceRule.get(), connectionContext).route(queryContext.getSqlStatementContext());
+ String actualDataSourceName = new ReadwriteSplittingDataSourceRouter(dataSourceRule.get(), connectionContext).route(queryContext.getSqlStatementContext(),
+ queryContext.getHintValueContext());
toBeAdded.add(new RouteUnit(new RouteMapper(each.getDataSourceMapper().getLogicName(), actualDataSourceName), each.getTableMappers()));
}
}
diff --git a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/QualifiedReadwriteSplittingDataSourceRouter.java b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/QualifiedReadwriteSplittingDataSourceRouter.java
index f3b277cba4d..7779a268426 100644
--- a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/QualifiedReadwriteSplittingDataSourceRouter.java
+++ b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/QualifiedReadwriteSplittingDataSourceRouter.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.readwritesplitting.route.qualified;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.readwritesplitting.rule.ReadwriteSplittingDataSourceRule;
/**
@@ -30,9 +31,11 @@ public interface QualifiedReadwriteSplittingDataSourceRouter {
*
* @param sqlStatementContext SQL statement context
* @param rule readwrite splitting datasource rule
+ * @param hintValueContext hint value context
+ *
* @return qualified to route or not
*/
- boolean isQualified(SQLStatementContext sqlStatementContext, ReadwriteSplittingDataSourceRule rule);
+ boolean isQualified(SQLStatementContext sqlStatementContext, ReadwriteSplittingDataSourceRule rule, HintValueContext hintValueContext);
/**
* Route to data source.
diff --git a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingPrimaryDataSourceRouter.java b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingPrimaryDataSourceRouter.java
index 5873098f949..192c9847ad5 100644
--- a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingPrimaryDataSourceRouter.java
+++ b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingPrimaryDataSourceRouter.java
@@ -17,10 +17,10 @@
package org.apache.shardingsphere.readwritesplitting.route.qualified.type;
-import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.hint.HintManager;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.readwritesplitting.route.qualified.QualifiedReadwriteSplittingDataSourceRouter;
import org.apache.shardingsphere.readwritesplitting.rule.ReadwriteSplittingDataSourceRule;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
@@ -33,12 +33,12 @@ import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatem
public final class QualifiedReadwriteSplittingPrimaryDataSourceRouter implements QualifiedReadwriteSplittingDataSourceRouter {
@Override
- public boolean isQualified(final SQLStatementContext sqlStatementContext, final ReadwriteSplittingDataSourceRule rule) {
- return isPrimaryRoute(sqlStatementContext);
+ public boolean isQualified(final SQLStatementContext sqlStatementContext, final ReadwriteSplittingDataSourceRule rule, final HintValueContext hintValueContext) {
+ return isPrimaryRoute(sqlStatementContext, hintValueContext);
}
- private boolean isPrimaryRoute(final SQLStatementContext sqlStatementContext) {
- return isWriteRouteStatement(sqlStatementContext) || isHintWriteRouteOnly(sqlStatementContext);
+ private boolean isPrimaryRoute(final SQLStatementContext sqlStatementContext, final HintValueContext hintValueContext) {
+ return isWriteRouteStatement(sqlStatementContext) || isHintWriteRouteOnly(hintValueContext);
}
private boolean isWriteRouteStatement(final SQLStatementContext sqlStatementContext) {
@@ -54,8 +54,8 @@ public final class QualifiedReadwriteSplittingPrimaryDataSourceRouter implements
return sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).getProjectionsContext().isContainsLastInsertIdProjection();
}
- private boolean isHintWriteRouteOnly(final SQLStatementContext sqlStatementContext) {
- return HintManager.isWriteRouteOnly() || sqlStatementContext instanceof CommonSQLStatementContext && ((CommonSQLStatementContext) sqlStatementContext).isHintWriteRouteOnly();
+ private boolean isHintWriteRouteOnly(final HintValueContext hintValueContext) {
+ return HintManager.isWriteRouteOnly() || hintValueContext.isWriteRouteOnly();
}
@Override
diff --git a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingTransactionalDataSourceRouter.java b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingTransactionalDataSourceRouter.java
index d5ce2692091..4ac20b0b499 100644
--- a/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingTransactionalDataSourceRouter.java
+++ b/features/readwrite-splitting/core/src/main/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingTransactionalDataSourceRouter.java
@@ -19,6 +19,7 @@ package org.apache.shardingsphere.readwritesplitting.route.qualified.type;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.readwritesplitting.route.qualified.QualifiedReadwriteSplittingDataSourceRouter;
import org.apache.shardingsphere.readwritesplitting.route.standard.StandardReadwriteSplittingDataSourceRouter;
@@ -35,7 +36,7 @@ public final class QualifiedReadwriteSplittingTransactionalDataSourceRouter impl
private final StandardReadwriteSplittingDataSourceRouter standardRouter = new StandardReadwriteSplittingDataSourceRouter();
@Override
- public boolean isQualified(final SQLStatementContext sqlStatementContext, final ReadwriteSplittingDataSourceRule rule) {
+ public boolean isQualified(final SQLStatementContext sqlStatementContext, final ReadwriteSplittingDataSourceRule rule, final HintValueContext hintValueContext) {
return connectionContext.getTransactionContext().isInTransaction();
}
diff --git a/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingSQLRouterTest.java b/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingSQLRouterTest.java
index 0b0fbd4b0b2..b36a66484fd 100644
--- a/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingSQLRouterTest.java
+++ b/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/ReadwriteSplittingSQLRouterTest.java
@@ -23,6 +23,7 @@ import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementConte
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.database.DefaultDatabase;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.instance.InstanceContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.ShardingSphereResourceMetaData;
@@ -219,9 +220,10 @@ class ReadwriteSplittingSQLRouterTest {
SelectStatement statement = mock(SelectStatement.class);
SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getSqlStatement()).thenReturn(statement);
- when(sqlStatementContext.isHintWriteRouteOnly()).thenReturn(true);
+ HintValueContext hintValueContext = mock(HintValueContext.class);
+ when(hintValueContext.isWriteRouteOnly()).thenReturn(true);
when(sqlStatementContext.getProjectionsContext().isContainsLastInsertIdProjection()).thenReturn(false);
- QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList());
+ QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext);
ShardingSphereRuleMetaData ruleMetaData = new ShardingSphereRuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ShardingSphereResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
diff --git a/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingPrimaryDataSourceRouterTest.java b/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingPrimaryDataSourceRouterTest.java
index b549b3cef97..7f0cd1192fd 100644
--- a/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingPrimaryDataSourceRouterTest.java
+++ b/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingPrimaryDataSourceRouterTest.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.readwritesplitting.route.qualified.type;
import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.LockSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
@@ -40,22 +41,25 @@ class QualifiedReadwriteSplittingPrimaryDataSourceRouterTest {
@Mock
private CommonSQLStatementContext sqlStatementContext;
+ @Mock
+ private HintValueContext hintValueContext;
+
@Test
void assertWriteRouteStatement() {
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(selectStatement.getLock()).thenReturn(Optional.of(new LockSegment(0, 1)));
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
- assertTrue(new QualifiedReadwriteSplittingPrimaryDataSourceRouter().isQualified(sqlStatementContext, null));
+ assertTrue(new QualifiedReadwriteSplittingPrimaryDataSourceRouter().isQualified(sqlStatementContext, null, hintValueContext));
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(MySQLUpdateStatement.class));
- assertTrue(new QualifiedReadwriteSplittingPrimaryDataSourceRouter().isQualified(sqlStatementContext, null));
+ assertTrue(new QualifiedReadwriteSplittingPrimaryDataSourceRouter().isQualified(sqlStatementContext, null, hintValueContext));
}
@Test
void assertHintRouteWriteOnly() {
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(SelectStatement.class));
- when(sqlStatementContext.isHintWriteRouteOnly()).thenReturn(false);
- assertFalse(new QualifiedReadwriteSplittingPrimaryDataSourceRouter().isQualified(sqlStatementContext, null));
- when(sqlStatementContext.isHintWriteRouteOnly()).thenReturn(true);
- assertTrue(new QualifiedReadwriteSplittingPrimaryDataSourceRouter().isQualified(sqlStatementContext, null));
+ when(hintValueContext.isWriteRouteOnly()).thenReturn(false);
+ assertFalse(new QualifiedReadwriteSplittingPrimaryDataSourceRouter().isQualified(sqlStatementContext, null, hintValueContext));
+ when(hintValueContext.isWriteRouteOnly()).thenReturn(true);
+ assertTrue(new QualifiedReadwriteSplittingPrimaryDataSourceRouter().isQualified(sqlStatementContext, null, hintValueContext));
}
}
diff --git a/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingTransactionalDataSourceRouterTest.java b/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingTransactionalDataSourceRouterTest.java
index ab91cd98711..c26021e835b 100644
--- a/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingTransactionalDataSourceRouterTest.java
+++ b/features/readwrite-splitting/core/src/test/java/org/apache/shardingsphere/readwritesplitting/route/qualified/type/QualifiedReadwriteSplittingTransactionalDataSourceRouterTest.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.readwritesplitting.route.qualified.type;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
import org.apache.shardingsphere.readwritesplitting.algorithm.loadbalance.RoundRobinReadQueryLoadBalanceAlgorithm;
@@ -24,6 +25,7 @@ import org.apache.shardingsphere.readwritesplitting.api.rule.ReadwriteSplittingD
import org.apache.shardingsphere.readwritesplitting.api.transaction.TransactionalReadQueryStrategy;
import org.apache.shardingsphere.readwritesplitting.rule.ReadwriteSplittingDataSourceRule;
import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
import java.util.Arrays;
@@ -36,15 +38,18 @@ import static org.mockito.Mockito.when;
class QualifiedReadwriteSplittingTransactionalDataSourceRouterTest {
+ @Mock
+ private HintValueContext hintValueContext;
+
@Test
void assertWriteRouteTransaction() {
ConnectionContext connectionContext = mock(ConnectionContext.class);
TransactionConnectionContext transactionConnectionContext = mock(TransactionConnectionContext.class);
when(connectionContext.getTransactionContext()).thenReturn(transactionConnectionContext);
when(connectionContext.getTransactionContext().isInTransaction()).thenReturn(Boolean.TRUE);
- assertTrue(new QualifiedReadwriteSplittingTransactionalDataSourceRouter(connectionContext).isQualified(null, null));
+ assertTrue(new QualifiedReadwriteSplittingTransactionalDataSourceRouter(connectionContext).isQualified(null, null, hintValueContext));
when(connectionContext.getTransactionContext().isInTransaction()).thenReturn(Boolean.FALSE);
- assertFalse(new QualifiedReadwriteSplittingTransactionalDataSourceRouter(connectionContext).isQualified(null, null));
+ assertFalse(new QualifiedReadwriteSplittingTransactionalDataSourceRouter(connectionContext).isQualified(null, null, hintValueContext));
}
@Test
diff --git a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/auditor/ShardingSQLAuditor.java b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/auditor/ShardingSQLAuditor.java
index f16c7ed7cf7..2467be2ecca 100644
--- a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/auditor/ShardingSQLAuditor.java
+++ b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/auditor/ShardingSQLAuditor.java
@@ -17,9 +17,9 @@
package org.apache.shardingsphere.sharding.auditor;
-import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.executor.audit.SQLAuditor;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
@@ -29,7 +29,6 @@ import org.apache.shardingsphere.sharding.rule.ShardingRule;
import java.util.ArrayList;
import java.util.Collection;
-import java.util.Collections;
import java.util.List;
/**
@@ -39,14 +38,12 @@ public final class ShardingSQLAuditor implements SQLAuditor<ShardingRule> {
@Override
public void audit(final SQLStatementContext sqlStatementContext, final List<Object> params, final Grantee grantee, final ShardingSphereRuleMetaData globalRuleMetaData,
- final ShardingSphereDatabase database, final ShardingRule rule) {
+ final ShardingSphereDatabase database, final ShardingRule rule, final HintValueContext hintValueContext) {
Collection<ShardingAuditStrategyConfiguration> auditStrategies = getShardingAuditStrategies(sqlStatementContext, rule);
if (auditStrategies.isEmpty()) {
return;
}
- Collection<String> disableAuditNames = sqlStatementContext instanceof CommonSQLStatementContext
- ? ((CommonSQLStatementContext) sqlStatementContext).getSqlHintExtractor().findDisableAuditNames()
- : Collections.emptyList();
+ Collection<String> disableAuditNames = hintValueContext.findDisableAuditNames();
for (ShardingAuditStrategyConfiguration auditStrategy : auditStrategies) {
for (String auditorName : auditStrategy.getAuditorNames()) {
if (!auditStrategy.isAllowHintDisable() || !disableAuditNames.contains(auditorName.toLowerCase())) {
diff --git a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/auditor/ShardingSQLAuditorTest.java b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/auditor/ShardingSQLAuditorTest.java
index 64ddb7213c7..a4126cff122 100644
--- a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/auditor/ShardingSQLAuditorTest.java
+++ b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/auditor/ShardingSQLAuditorTest.java
@@ -19,6 +19,7 @@ package org.apache.shardingsphere.sharding.auditor;
import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.executor.audit.exception.SQLAuditException;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
@@ -62,11 +63,14 @@ class ShardingSQLAuditorTest {
@Mock
private ShardingAuditStrategyConfiguration auditStrategy;
+ @Mock
+ private HintValueContext hintValueContext;
+
private final Map<String, ShardingSphereDatabase> databases = Collections.singletonMap("foo_db", mock(ShardingSphereDatabase.class));
@BeforeEach
void setUp() {
- when(sqlStatementContext.getSqlHintExtractor().findDisableAuditNames()).thenReturn(new HashSet<>(Collections.singletonList("auditor_1")));
+ when(hintValueContext.findDisableAuditNames()).thenReturn(new HashSet<>(Collections.singletonList("auditor_1")));
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singletonList("foo_table"));
TableRule tableRule = mock(TableRule.class);
when(rule.findTableRule("foo_table")).thenReturn(Optional.of(tableRule));
@@ -77,7 +81,7 @@ class ShardingSQLAuditorTest {
@Test
void assertCheckSuccess() {
ShardingSphereRuleMetaData globalRuleMetaData = mock(ShardingSphereRuleMetaData.class);
- new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule);
+ new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule, hintValueContext);
verify(rule.getAuditors().get("auditor_1")).check(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"));
}
@@ -85,7 +89,7 @@ class ShardingSQLAuditorTest {
void assertCheckSuccessByDisableAuditNames() {
when(auditStrategy.isAllowHintDisable()).thenReturn(true);
ShardingSphereRuleMetaData globalRuleMetaData = mock(ShardingSphereRuleMetaData.class);
- new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule);
+ new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule, hintValueContext);
verify(rule.getAuditors().get("auditor_1"), times(0)).check(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"));
}
@@ -96,7 +100,7 @@ class ShardingSQLAuditorTest {
doThrow(new SQLAuditException("Not allow DML operation without sharding conditions"))
.when(auditAlgorithm).check(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"));
SQLAuditException ex = assertThrows(SQLAuditException.class,
- () -> new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule));
+ () -> new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule, hintValueContext));
assertThat(ex.getMessage(), is("SQL audit failed, error message: Not allow DML operation without sharding conditions."));
verify(rule.getAuditors().get("auditor_1")).check(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"));
}
diff --git a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/cache/route/CachedShardingSQLRouterTest.java b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/cache/route/CachedShardingSQLRouterTest.java
index bdb37a9876e..173b526b952 100644
--- a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/cache/route/CachedShardingSQLRouterTest.java
+++ b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/cache/route/CachedShardingSQLRouterTest.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.sharding.cache.route;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.route.context.RouteContext;
@@ -57,17 +58,20 @@ class CachedShardingSQLRouterTest {
@Mock
private ShardingCache shardingCache;
+ @Mock
+ private SQLStatementContext sqlStatementContext;
+
@Test
void assertCreateRouteContextWithSQLExceedMaxAllowedLength() {
when(shardingCache.getConfiguration()).thenReturn(new ShardingCacheConfiguration(1, null));
- QueryContext queryContext = new QueryContext(null, "select 1", Collections.emptyList());
+ QueryContext queryContext = new QueryContext(sqlStatementContext, "select 1", Collections.emptyList());
Optional<RouteContext> actual = new CachedShardingSQLRouter().loadRouteContext(null, queryContext, mock(ShardingSphereRuleMetaData.class), null, shardingCache, null, null);
assertFalse(actual.isPresent());
}
@Test
void assertCreateRouteContextWithNotCacheableQuery() {
- QueryContext queryContext = new QueryContext(null, "insert into t values (?), (?)", Collections.emptyList());
+ QueryContext queryContext = new QueryContext(sqlStatementContext, "insert into t values (?), (?)", Collections.emptyList());
when(shardingCache.getConfiguration()).thenReturn(new ShardingCacheConfiguration(100, null));
when(shardingCache.getRouteCacheableChecker()).thenReturn(mock(ShardingRouteCacheableChecker.class));
when(shardingCache.getRouteCacheableChecker().check(null, queryContext)).thenReturn(new ShardingRouteCacheableCheckResult(false, Collections.emptyList()));
@@ -77,7 +81,7 @@ class CachedShardingSQLRouterTest {
@Test
void assertCreateRouteContextWithUnmatchedActualParameterSize() {
- QueryContext queryContext = new QueryContext(null, "insert into t values (?, ?)", Collections.singletonList(0));
+ QueryContext queryContext = new QueryContext(sqlStatementContext, "insert into t values (?, ?)", Collections.singletonList(0));
when(shardingCache.getConfiguration()).thenReturn(new ShardingCacheConfiguration(100, null));
when(shardingCache.getRouteCacheableChecker()).thenReturn(mock(ShardingRouteCacheableChecker.class));
when(shardingCache.getRouteCacheableChecker().check(null, queryContext)).thenReturn(new ShardingRouteCacheableCheckResult(true, Collections.singletonList(1)));
@@ -87,7 +91,7 @@ class CachedShardingSQLRouterTest {
@Test
void assertCreateRouteContextWithCacheableQueryButCacheMissed() {
- QueryContext queryContext = new QueryContext(null, "insert into t values (?, ?)", Arrays.asList(0, 1));
+ QueryContext queryContext = new QueryContext(sqlStatementContext, "insert into t values (?, ?)", Arrays.asList(0, 1));
when(shardingCache.getConfiguration()).thenReturn(new ShardingCacheConfiguration(100, null));
when(shardingCache.getRouteCacheableChecker()).thenReturn(mock(ShardingRouteCacheableChecker.class));
when(shardingCache.getRouteCacheableChecker().check(null, queryContext)).thenReturn(new ShardingRouteCacheableCheckResult(true, Collections.singletonList(1)));
@@ -105,7 +109,7 @@ class CachedShardingSQLRouterTest {
@Test
void assertCreateRouteContextWithCacheHit() {
- QueryContext queryContext = new QueryContext(null, "insert into t values (?, ?)", Arrays.asList(0, 1));
+ QueryContext queryContext = new QueryContext(sqlStatementContext, "insert into t values (?, ?)", Arrays.asList(0, 1));
when(shardingCache.getConfiguration()).thenReturn(new ShardingCacheConfiguration(100, null));
when(shardingCache.getRouteCacheableChecker()).thenReturn(mock(ShardingRouteCacheableChecker.class));
when(shardingCache.getRouteCacheableChecker().check(null, queryContext)).thenReturn(new ShardingRouteCacheableCheckResult(true, Collections.singletonList(1)));
@@ -124,7 +128,7 @@ class CachedShardingSQLRouterTest {
@Test
void assertCreateRouteContextWithQueryRoutedToMultiDataNodes() {
- QueryContext queryContext = new QueryContext(null, "select * from t", Collections.emptyList());
+ QueryContext queryContext = new QueryContext(sqlStatementContext, "select * from t", Collections.emptyList());
when(shardingCache.getConfiguration()).thenReturn(new ShardingCacheConfiguration(100, null));
when(shardingCache.getRouteCacheableChecker()).thenReturn(mock(ShardingRouteCacheableChecker.class));
when(shardingCache.getRouteCacheableChecker().check(null, queryContext)).thenReturn(new ShardingRouteCacheableCheckResult(true, Collections.emptyList()));
diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/CommonSQLStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/CommonSQLStatementContext.java
index 8a17de2000e..8892dffd267 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/CommonSQLStatementContext.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/CommonSQLStatementContext.java
@@ -20,7 +20,6 @@ package org.apache.shardingsphere.infra.binder.statement;
import lombok.Getter;
import org.apache.shardingsphere.infra.binder.segment.table.TablesContext;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
-import org.apache.shardingsphere.infra.hint.SQLHintExtractor;
import org.apache.shardingsphere.infra.util.exception.external.sql.type.generic.UnsupportedSQLOperationException;
import org.apache.shardingsphere.infra.util.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
@@ -32,7 +31,6 @@ import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sql92.SQL92Sta
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.SQLServerStatement;
import java.util.Collections;
-import java.util.Optional;
/**
* Common SQL statement context.
@@ -46,13 +44,10 @@ public abstract class CommonSQLStatementContext implements SQLStatementContext {
private final DatabaseType databaseType;
- private final SQLHintExtractor sqlHintExtractor;
-
protected CommonSQLStatementContext(final SQLStatement sqlStatement) {
this.sqlStatement = sqlStatement;
databaseType = getDatabaseType(sqlStatement);
tablesContext = new TablesContext(Collections.emptyList(), databaseType);
- sqlHintExtractor = new SQLHintExtractor(sqlStatement);
}
private DatabaseType getDatabaseType(final SQLStatement sqlStatement) {
@@ -76,31 +71,4 @@ public abstract class CommonSQLStatementContext implements SQLStatementContext {
}
throw new UnsupportedSQLOperationException(sqlStatement.getClass().getName());
}
-
- /**
- * Find hint data source name.
- *
- * @return dataSource name
- */
- public Optional<String> findHintDataSourceName() {
- return sqlHintExtractor.findHintDataSourceName();
- }
-
- /**
- * Judge whether is hint routed to write data source or not.
- *
- * @return whether is hint routed to write data source or not
- */
- public boolean isHintWriteRouteOnly() {
- return sqlHintExtractor.isHintWriteRouteOnly();
- }
-
- /**
- * Judge whether hint skip sql rewrite or not.
- *
- * @return whether hint skip sql rewrite or not
- */
- public boolean isHintSkipSQLRewrite() {
- return sqlHintExtractor.isHintSkipSQLRewrite();
- }
}
diff --git a/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/HintValueContext.java b/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/HintValueContext.java
index 6a8ce5c18f9..0372a3cf41b 100644
--- a/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/HintValueContext.java
+++ b/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/HintValueContext.java
@@ -23,6 +23,9 @@ import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
+import java.util.Collection;
+import java.util.Optional;
+
/**
* Hint value context.
*/
@@ -48,4 +51,22 @@ public final class HintValueContext {
private String disableAuditNames = "";
private boolean shadow;
+
+ /**
+ * Find hint disable audit names.
+ *
+ * @return disable audit names
+ */
+ public Collection<String> findDisableAuditNames() {
+ return SQLHintUtils.getSplitterSQLHintValue(disableAuditNames);
+ }
+
+ /**
+ * Find hint data source name.
+ *
+ * @return data source name
+ */
+ public Optional<String> findHintDataSourceName() {
+ return dataSourceName.isEmpty() ? Optional.empty() : Optional.of(dataSourceName);
+ }
}
diff --git a/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/SQLHintExtractor.java b/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/SQLHintExtractor.java
index 76bf7029e24..443fdb107b0 100644
--- a/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/SQLHintExtractor.java
+++ b/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/SQLHintExtractor.java
@@ -24,7 +24,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStat
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import java.util.Collection;
-import java.util.Optional;
/**
* SQL hint extractor.
@@ -35,7 +34,7 @@ public final class SQLHintExtractor {
private final HintValueContext hintValueContext;
public SQLHintExtractor(final String sqlComment) {
- hintValueContext = Strings.isNullOrEmpty(sqlComment) ? new HintValueContext() : SQLHintUtils.extractHint(sqlComment);
+ hintValueContext = Strings.isNullOrEmpty(sqlComment) ? new HintValueContext() : SQLHintUtils.extractHint(sqlComment).orElseGet(HintValueContext::new);
}
public SQLHintExtractor(final SQLStatement sqlStatement) {
@@ -44,20 +43,10 @@ public final class SQLHintExtractor {
public SQLHintExtractor(final SQLStatement sqlStatement, final HintValueContext hintValueContext) {
this.hintValueContext = sqlStatement instanceof AbstractSQLStatement && !((AbstractSQLStatement) sqlStatement).getCommentSegments().isEmpty()
- ? SQLHintUtils.extractHint(((AbstractSQLStatement) sqlStatement).getCommentSegments().iterator().next().getText())
+ ? SQLHintUtils.extractHint(((AbstractSQLStatement) sqlStatement).getCommentSegments().iterator().next().getText()).orElse(hintValueContext)
: hintValueContext;
}
- /**
- * Find hint data source name.
- *
- * @return data source name
- */
- public Optional<String> findHintDataSourceName() {
- String result = hintValueContext.getDataSourceName();
- return result.isEmpty() ? Optional.empty() : Optional.of(result);
- }
-
/**
* Judge whether is hint routed to write data source or not.
*
diff --git a/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/SQLHintUtils.java b/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/SQLHintUtils.java
index 4743634de0e..7541368ac6d 100644
--- a/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/SQLHintUtils.java
+++ b/infra/common/src/main/java/org/apache/shardingsphere/infra/hint/SQLHintUtils.java
@@ -28,6 +28,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Objects;
+import java.util.Optional;
import java.util.Properties;
/**
@@ -107,11 +108,11 @@ public final class SQLHintUtils {
* @param sql SQL
* @return Hint value context
*/
- public static HintValueContext extractHint(final String sql) {
- HintValueContext result = new HintValueContext();
+ public static Optional<HintValueContext> extractHint(final String sql) {
if (null == sql || !startWithHint(sql)) {
- return result;
+ return Optional.empty();
}
+ HintValueContext result = new HintValueContext();
String hintText = sql.substring(0, sql.indexOf(SQL_COMMENT_SUFFIX) + 2);
Properties hintProperties = SQLHintUtils.getSQLHintProps(hintText);
if (containsPropertyKey(hintProperties, SQLHintPropertiesKey.DATASOURCE_NAME_KEY)) {
@@ -141,7 +142,7 @@ public final class SQLHintUtils {
result.getShardingTableValues().put(Objects.toString(entry.getKey()).toUpperCase(), value);
}
}
- return result;
+ return Optional.of(result);
}
private static boolean containsPropertyKey(final Properties hintProperties, final SQLHintPropertiesKey sqlHintPropertiesKey) {
diff --git a/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/SQLHintExtractorTest.java b/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/SQLHintExtractorTest.java
index c562ea3e60f..ea3a00244a0 100644
--- a/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/SQLHintExtractorTest.java
+++ b/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/SQLHintExtractorTest.java
@@ -43,11 +43,6 @@ class SQLHintExtractorTest {
assertTrue(new SQLHintExtractor(statement).isHintWriteRouteOnly());
}
- @Test
- void assertSQLHintWriteRouteOnlyWithCommentString() {
- assertTrue(new SQLHintExtractor("/* SHARDINGSPHERE_HINT: WRITE_ROUTE_ONLY=true */").isHintWriteRouteOnly());
- }
-
@Test
void assertSQLHintSkipSQLRewrite() {
AbstractSQLStatement statement = mock(AbstractSQLStatement.class);
@@ -144,7 +139,7 @@ class SQLHintExtractorTest {
void assertFindHintDataSourceNameExist() {
AbstractSQLStatement statement = mock(AbstractSQLStatement.class);
when(statement.getCommentSegments()).thenReturn(Collections.singletonList(new CommentSegment("/* SHARDINGSPHERE_HINT: DATA_SOURCE_NAME=ds_1 */", 0, 0)));
- Optional<String> dataSourceName = new SQLHintExtractor(statement).findHintDataSourceName();
+ Optional<String> dataSourceName = new SQLHintExtractor(statement).getHintValueContext().findHintDataSourceName();
assertTrue(dataSourceName.isPresent());
assertThat(dataSourceName.get(), is("ds_1"));
}
@@ -153,7 +148,7 @@ class SQLHintExtractorTest {
void assertFindHintDataSourceNameAliasExist() {
AbstractSQLStatement statement = mock(AbstractSQLStatement.class);
when(statement.getCommentSegments()).thenReturn(Collections.singletonList(new CommentSegment("/* ShardingSphere hint: dataSourceName=ds_1 */", 0, 0)));
- Optional<String> dataSourceName = new SQLHintExtractor(statement).findHintDataSourceName();
+ Optional<String> dataSourceName = new SQLHintExtractor(statement).getHintValueContext().findHintDataSourceName();
assertTrue(dataSourceName.isPresent());
assertThat(dataSourceName.get(), is("ds_1"));
}
@@ -162,14 +157,14 @@ class SQLHintExtractorTest {
void assertFindHintDataSourceNameNotExist() {
AbstractSQLStatement statement = mock(AbstractSQLStatement.class);
when(statement.getCommentSegments()).thenReturn(Collections.singletonList(new CommentSegment("/* no hint */", 0, 0)));
- Optional<String> dataSourceName = new SQLHintExtractor(statement).findHintDataSourceName();
+ Optional<String> dataSourceName = new SQLHintExtractor(statement).getHintValueContext().findHintDataSourceName();
assertFalse(dataSourceName.isPresent());
}
@Test
void assertFindHintDataSourceNameNotExistWithoutComment() {
AbstractSQLStatement statement = mock(AbstractSQLStatement.class);
- Optional<String> dataSourceName = new SQLHintExtractor(statement).findHintDataSourceName();
+ Optional<String> dataSourceName = new SQLHintExtractor(statement).getHintValueContext().findHintDataSourceName();
assertFalse(dataSourceName.isPresent());
}
}
diff --git a/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/SQLHintUtilsTest.java b/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/SQLHintUtilsTest.java
index 9bb4d117ad6..953acac4360 100644
--- a/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/SQLHintUtilsTest.java
+++ b/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/SQLHintUtilsTest.java
@@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Optional;
import java.util.Properties;
import static org.hamcrest.CoreMatchers.is;
@@ -77,4 +78,11 @@ class SQLHintUtilsTest {
assertThat(actual.size(), is(1));
assertThat(actual.get("dataSourceName"), is("ds_0"));
}
+
+ @Test
+ void assertSQLHintWriteRouteOnlyWithCommentString() {
+ Optional<HintValueContext> actual = SQLHintUtils.extractHint("/* SHARDINGSPHERE_HINT: WRITE_ROUTE_ONLY=true */");
+ assertTrue(actual.isPresent());
+ assertTrue(actual.get().isWriteRouteOnly());
+ }
}
diff --git a/infra/context/src/main/java/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java b/infra/context/src/main/java/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java
index 4528d58e3f0..20741a495ad 100644
--- a/infra/context/src/main/java/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java
+++ b/infra/context/src/main/java/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java
@@ -17,10 +17,8 @@
package org.apache.shardingsphere.infra.connection.kernel;
-import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContextBuilder;
import org.apache.shardingsphere.infra.executor.sql.log.SQLLogger;
@@ -30,6 +28,8 @@ import org.apache.shardingsphere.infra.rewrite.SQLRewriteEntry;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteResult;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.engine.SQLRouteEngine;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
/**
* Kernel processor.
@@ -63,7 +63,7 @@ public final class KernelProcessor {
private SQLRewriteResult rewrite(final QueryContext queryContext, final ShardingSphereDatabase database, final ShardingSphereRuleMetaData globalRuleMetaData,
final ConfigurationProperties props, final RouteContext routeContext, final ConnectionContext connectionContext) {
SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(database, globalRuleMetaData, props);
- return sqlRewriteEntry.rewrite(queryContext.getSql(), queryContext.getParameters(), queryContext.getSqlStatementContext(), routeContext, connectionContext);
+ return sqlRewriteEntry.rewrite(queryContext.getSql(), queryContext.getParameters(), queryContext.getSqlStatementContext(), routeContext, connectionContext, queryContext.getHintValueContext());
}
private ExecutionContext createExecutionContext(final QueryContext queryContext, final ShardingSphereDatabase database, final RouteContext routeContext, final SQLRewriteResult rewriteResult) {
diff --git a/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/audit/SQLAuditEngine.java b/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/audit/SQLAuditEngine.java
index c7b7016ccf6..4e36f90bc62 100644
--- a/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/audit/SQLAuditEngine.java
+++ b/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/audit/SQLAuditEngine.java
@@ -20,6 +20,7 @@ package org.apache.shardingsphere.infra.executor.audit;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
@@ -45,16 +46,17 @@ public final class SQLAuditEngine {
* @param globalRuleMetaData global rule meta data
* @param database database
* @param grantee grantee
+ * @param hintValueContext hint value context
*/
@SuppressWarnings({"rawtypes", "unchecked"})
public static void audit(final SQLStatementContext sqlStatementContext, final List<Object> params,
- final ShardingSphereRuleMetaData globalRuleMetaData, final ShardingSphereDatabase database, final Grantee grantee) {
+ final ShardingSphereRuleMetaData globalRuleMetaData, final ShardingSphereDatabase database, final Grantee grantee, final HintValueContext hintValueContext) {
Collection<ShardingSphereRule> rules = new LinkedList<>(globalRuleMetaData.getRules());
if (null != database) {
rules.addAll(database.getRuleMetaData().getRules());
}
for (Entry<ShardingSphereRule, SQLAuditor> entry : OrderedSPILoader.getServices(SQLAuditor.class, rules).entrySet()) {
- entry.getValue().audit(sqlStatementContext, params, grantee, globalRuleMetaData, database, entry.getKey());
+ entry.getValue().audit(sqlStatementContext, params, grantee, globalRuleMetaData, database, entry.getKey(), hintValueContext);
}
}
}
diff --git a/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/audit/SQLAuditor.java b/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/audit/SQLAuditor.java
index 82b68693504..93410038673 100644
--- a/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/audit/SQLAuditor.java
+++ b/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/audit/SQLAuditor.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.infra.executor.audit;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
@@ -44,6 +45,8 @@ public interface SQLAuditor<T extends ShardingSphereRule> extends OrderedSPI<T>
* @param globalRuleMetaData global rule meta data
* @param database current database
* @param rule rule
+ * @param hintValueContext hint value context
*/
- void audit(SQLStatementContext sqlStatementContext, List<Object> params, Grantee grantee, ShardingSphereRuleMetaData globalRuleMetaData, ShardingSphereDatabase database, T rule);
+ void audit(SQLStatementContext sqlStatementContext, List<Object> params, Grantee grantee, ShardingSphereRuleMetaData globalRuleMetaData,
+ ShardingSphereDatabase database, T rule, HintValueContext hintValueContext);
}
diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
index a5a525198f6..a9eb260f535 100644
--- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
+++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
@@ -17,11 +17,10 @@
package org.apache.shardingsphere.infra.rewrite;
-import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
@@ -31,6 +30,7 @@ import org.apache.shardingsphere.infra.rewrite.engine.RouteSQLRewriteEngine;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteResult;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.util.spi.type.ordered.OrderedSPILoader;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
@@ -67,11 +67,13 @@ public final class SQLRewriteEntry {
* @param sqlStatementContext SQL statement context
* @param routeContext route context
* @param connectionContext connection context
+ * @param hintValueContext hint value context
+ *
* @return route unit and SQL rewrite result map
*/
public SQLRewriteResult rewrite(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
- final RouteContext routeContext, final ConnectionContext connectionContext) {
- SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext);
+ final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
+ SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext);
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
DatabaseType protocolType = database.getProtocolType();
Map<String, DatabaseType> storageTypes = database.getResourceMetaData().getStorageTypes();
@@ -81,16 +83,17 @@ public final class SQLRewriteEntry {
}
private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
- final RouteContext routeContext, final ConnectionContext connectionContext) {
- SQLRewriteContext result = new SQLRewriteContext(database.getName(), database.getSchemas(), sqlStatementContext, sql, params, connectionContext);
- decorate(decorators, result, routeContext);
+ final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
+ SQLRewriteContext result = new SQLRewriteContext(database.getName(), database.getSchemas(), sqlStatementContext, sql, params, connectionContext, hintValueContext);
+ decorate(decorators, result, routeContext, hintValueContext);
result.generateSQLTokens();
return result;
}
@SuppressWarnings({"unchecked", "rawtypes"})
- private void decorate(final Map<ShardingSphereRule, SQLRewriteContextDecorator> decorators, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
- if (((CommonSQLStatementContext) sqlRewriteContext.getSqlStatementContext()).isHintSkipSQLRewrite()) {
+ private void decorate(final Map<ShardingSphereRule, SQLRewriteContextDecorator> decorators, final SQLRewriteContext sqlRewriteContext,
+ final RouteContext routeContext, final HintValueContext hintValueContext) {
+ if (hintValueContext.isSkipSQLRewrite()) {
return;
}
for (Entry<ShardingSphereRule, SQLRewriteContextDecorator> entry : decorators.entrySet()) {
diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
index 90a80ce961c..af081b4468c 100644
--- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
+++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
@@ -19,10 +19,9 @@ package org.apache.shardingsphere.infra.rewrite.context;
import lombok.AccessLevel;
import lombok.Getter;
-import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
@@ -31,6 +30,7 @@ import org.apache.shardingsphere.infra.rewrite.sql.token.generator.SQLTokenGener
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.SQLTokenGenerators;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.builder.DefaultTokenGeneratorBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import java.util.Collection;
import java.util.LinkedList;
@@ -62,15 +62,15 @@ public final class SQLRewriteContext {
private final ConnectionContext connectionContext;
- public SQLRewriteContext(final String databaseName, final Map<String, ShardingSphereSchema> schemas,
- final SQLStatementContext sqlStatementContext, final String sql, final List<Object> params, final ConnectionContext connectionContext) {
+ public SQLRewriteContext(final String databaseName, final Map<String, ShardingSphereSchema> schemas, final SQLStatementContext sqlStatementContext,
+ final String sql, final List<Object> params, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
this.databaseName = databaseName;
this.schemas = schemas;
this.sqlStatementContext = sqlStatementContext;
this.sql = sql;
parameters = params;
this.connectionContext = connectionContext;
- if (!((CommonSQLStatementContext) sqlStatementContext).isHintSkipSQLRewrite()) {
+ if (!hintValueContext.isSkipSQLRewrite()) {
addSQLTokenGenerators(new DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators());
}
parameterBuilder = sqlStatementContext instanceof InsertStatementContext && null == ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntryTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntryTest.java
index 06426ea03ed..d125d5fddbf 100644
--- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntryTest.java
+++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntryTest.java
@@ -19,11 +19,11 @@ package org.apache.shardingsphere.infra.rewrite;
import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.database.DefaultDatabase;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
import org.apache.shardingsphere.infra.database.type.dialect.H2DatabaseType;
import org.apache.shardingsphere.infra.database.type.dialect.MySQLDatabaseType;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.ShardingSphereResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
@@ -33,6 +33,7 @@ import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResu
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.util.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
@@ -59,7 +60,7 @@ class SQLRewriteEntryTest {
database, new ShardingSphereRuleMetaData(Collections.singleton(new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()))), new ConfigurationProperties(new Properties()));
RouteContext routeContext = new RouteContext();
GenericSQLRewriteResult sqlRewriteResult = (GenericSQLRewriteResult) sqlRewriteEntry.rewrite("SELECT ?", Collections.singletonList(1), mock(CommonSQLStatementContext.class), routeContext,
- mock(ConnectionContext.class));
+ mock(ConnectionContext.class), new HintValueContext());
assertThat(sqlRewriteResult.getSqlRewriteUnit().getSql(), is("SELECT ?"));
assertThat(sqlRewriteResult.getSqlRewriteUnit().getParameters(), is(Collections.singletonList(1)));
}
@@ -77,7 +78,7 @@ class SQLRewriteEntryTest {
when(secondRouteUnit.getDataSourceMapper()).thenReturn(new RouteMapper("ds", "ds_1"));
routeContext.getRouteUnits().addAll(Arrays.asList(firstRouteUnit, secondRouteUnit));
RouteSQLRewriteResult sqlRewriteResult = (RouteSQLRewriteResult) sqlRewriteEntry.rewrite("SELECT ?", Collections.singletonList(1), mock(CommonSQLStatementContext.class), routeContext,
- mock(ConnectionContext.class));
+ mock(ConnectionContext.class), new HintValueContext());
assertThat(sqlRewriteResult.getSqlRewriteUnits().size(), is(2));
}
diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
index f4a63a2d5c7..4001ffdb354 100644
--- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
+++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
@@ -21,14 +21,15 @@ import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContex
import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.type.TableAvailable;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.database.DefaultDatabase;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.OptionalSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -62,6 +63,9 @@ class SQLRewriteContextTest {
@Mock
private CollectionSQLTokenGenerator collectionSQLTokenGenerator;
+ @Mock
+ private HintValueContext hintValueContext;
+
@SuppressWarnings("unchecked")
@BeforeEach
void setUp() {
@@ -74,8 +78,8 @@ class SQLRewriteContextTest {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getInsertSelectContext()).thenReturn(null);
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
assertThat(sqlRewriteContext.getParameterBuilder(), instanceOf(GroupedParameterBuilder.class));
}
@@ -83,15 +87,15 @@ class SQLRewriteContextTest {
void assertNotInsertStatementContext() {
SelectStatementContext statementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext, "SELECT * FROM tbl WHERE id = ?", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ statementContext, "SELECT * FROM tbl WHERE id = ?", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
assertThat(sqlRewriteContext.getParameterBuilder(), instanceOf(StandardParameterBuilder.class));
}
@Test
void assertGenerateOptionalSQLToken() {
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), sqlStatementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ sqlStatementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(optionalSQLTokenGenerator));
sqlRewriteContext.generateSQLTokens();
assertFalse(sqlRewriteContext.getSqlTokens().isEmpty());
@@ -100,8 +104,8 @@ class SQLRewriteContextTest {
@Test
void assertGenerateCollectionSQLToken() {
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), sqlStatementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ sqlStatementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(collectionSQLTokenGenerator));
sqlRewriteContext.generateSQLTokens();
assertFalse(sqlRewriteContext.getSqlTokens().isEmpty());
diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
index 1ca53ad905c..b498c147bab 100644
--- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
+++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
@@ -18,12 +18,13 @@
package org.apache.shardingsphere.infra.rewrite.engine;
import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.database.DefaultDatabase;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
import org.junit.jupiter.api.Test;
@@ -41,7 +42,8 @@ class GenericSQLRewriteEngineTest {
DatabaseType databaseType = mock(DatabaseType.class);
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, Collections.singletonMap("ds_0", databaseType)).rewrite(new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class)));
+ Collections.singletonMap("test", mock(ShardingSphereSchema.class)), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
+ new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}
@@ -50,7 +52,8 @@ class GenericSQLRewriteEngineTest {
void assertRewriteStorageTypeIsEmpty() {
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, mock(DatabaseType.class), Collections.emptyMap()).rewrite(new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class)));
+ Collections.singletonMap("test", mock(ShardingSphereSchema.class)), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
+ new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}
diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
index 3143e4bb6bb..57c700c22d1 100644
--- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
+++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
@@ -21,16 +21,17 @@ import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContex
import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.type.TableAvailable;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.database.DefaultDatabase;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
import org.apache.shardingsphere.infra.datanode.DataNode;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
import org.junit.jupiter.api.Test;
@@ -49,8 +50,8 @@ class RouteSQLRewriteEngineTest {
@Test
void assertRewriteWithStandardParameterBuilder() {
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), mock(CommonSQLStatementContext.class), "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ mock(CommonSQLStatementContext.class), "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
RouteContext routeContext = new RouteContext();
routeContext.getRouteUnits().add(routeUnit);
@@ -67,8 +68,8 @@ class RouteSQLRewriteEngineTest {
SelectStatementContext statementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(statementContext.getOrderByContext().getItems()).thenReturn(Collections.emptyList());
when(statementContext.getPaginationContext().isHasPagination()).thenReturn(false);
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext, "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ statementContext, "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteContext routeContext = new RouteContext();
RouteUnit firstRouteUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
RouteUnit secondRouteUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_1")));
@@ -87,8 +88,8 @@ class RouteSQLRewriteEngineTest {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
RouteContext routeContext = new RouteContext();
routeContext.getRouteUnits().add(routeUnit);
@@ -105,8 +106,8 @@ class RouteSQLRewriteEngineTest {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
RouteContext routeContext = new RouteContext();
routeContext.getRouteUnits().add(routeUnit);
@@ -125,8 +126,8 @@ class RouteSQLRewriteEngineTest {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
RouteContext routeContext = new RouteContext();
routeContext.getRouteUnits().add(routeUnit);
@@ -146,8 +147,8 @@ class RouteSQLRewriteEngineTest {
when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
- SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
- Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class));
+ SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)),
+ statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
RouteContext routeContext = new RouteContext();
routeContext.getRouteUnits().add(routeUnit);
diff --git a/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/impl/PartialSQLRouteExecutor.java b/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/impl/PartialSQLRouteExecutor.java
index 957fae229a3..21f45c85633 100644
--- a/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/impl/PartialSQLRouteExecutor.java
+++ b/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/impl/PartialSQLRouteExecutor.java
@@ -17,12 +17,9 @@
package org.apache.shardingsphere.infra.route.engine.impl;
-import org.apache.shardingsphere.infra.session.query.QueryContext;
-import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
-import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.hint.HintManager;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.hint.SQLHintDataSourceNotExistsException;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
@@ -32,6 +29,8 @@ import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.route.engine.SQLRouteExecutor;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.util.spi.type.ordered.OrderedSPILoader;
import javax.sql.DataSource;
@@ -60,7 +59,7 @@ public final class PartialSQLRouteExecutor implements SQLRouteExecutor {
@SuppressWarnings({"unchecked", "rawtypes"})
public RouteContext route(final ConnectionContext connectionContext, final QueryContext queryContext, final ShardingSphereRuleMetaData globalRuleMetaData, final ShardingSphereDatabase database) {
RouteContext result = new RouteContext();
- Optional<String> dataSourceName = findDataSourceByHint(queryContext.getSqlStatementContext(), database.getResourceMetaData().getDataSources());
+ Optional<String> dataSourceName = findDataSourceByHint(queryContext.getHintValueContext(), database.getResourceMetaData().getDataSources());
if (dataSourceName.isPresent()) {
result.getRouteUnits().add(new RouteUnit(new RouteMapper(dataSourceName.get(), dataSourceName.get()), Collections.emptyList()));
return result;
@@ -79,12 +78,12 @@ public final class PartialSQLRouteExecutor implements SQLRouteExecutor {
return result;
}
- private Optional<String> findDataSourceByHint(final SQLStatementContext sqlStatementContext, final Map<String, DataSource> dataSources) {
+ private Optional<String> findDataSourceByHint(final HintValueContext hintValueContext, final Map<String, DataSource> dataSources) {
Optional<String> result;
if (HintManager.isInstantiated() && HintManager.getDataSourceName().isPresent()) {
result = HintManager.getDataSourceName();
} else {
- result = ((CommonSQLStatementContext) sqlStatementContext).findHintDataSourceName();
+ result = hintValueContext.findHintDataSourceName();
}
if (result.isPresent() && !dataSources.containsKey(result.get())) {
throw new SQLHintDataSourceNotExistsException(result.get());
diff --git a/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/impl/PartialSQLRouteExecutorTest.java b/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/impl/PartialSQLRouteExecutorTest.java
index 7ec05d27279..53ea30051a3 100644
--- a/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/impl/PartialSQLRouteExecutorTest.java
+++ b/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/impl/PartialSQLRouteExecutorTest.java
@@ -20,6 +20,7 @@ package org.apache.shardingsphere.infra.route.engine.impl;
import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.hint.HintManager;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.hint.SQLHintDataSourceNotExistsException;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
@@ -57,6 +58,9 @@ class PartialSQLRouteExecutorTest {
@Mock
private CommonSQLStatementContext commonSQLStatementContext;
+ @Mock
+ private HintValueContext hintValueContext;
+
private final ConnectionContext connectionContext = new ConnectionContext();
@BeforeEach
@@ -69,8 +73,8 @@ class PartialSQLRouteExecutorTest {
@Test
void assertRouteBySQLCommentHint() {
- when(commonSQLStatementContext.findHintDataSourceName()).thenReturn(Optional.of("ds_1"));
- QueryContext queryContext = new QueryContext(commonSQLStatementContext, "", Collections.emptyList());
+ when(hintValueContext.findHintDataSourceName()).thenReturn(Optional.of("ds_1"));
+ QueryContext queryContext = new QueryContext(commonSQLStatementContext, "", Collections.emptyList(), hintValueContext);
RouteContext routeContext = partialSQLRouteExecutor.route(connectionContext, queryContext, mock(ShardingSphereRuleMetaData.class), database);
assertThat(routeContext.getRouteUnits().size(), is(1));
assertThat(routeContext.getRouteUnits().iterator().next().getDataSourceMapper().getActualName(), is("ds_1"));
@@ -89,8 +93,8 @@ class PartialSQLRouteExecutorTest {
@Test
void assertRouteBySQLCommentHintWithException() {
- when(commonSQLStatementContext.findHintDataSourceName()).thenReturn(Optional.of("ds_3"));
- QueryContext queryContext = new QueryContext(commonSQLStatementContext, "", Collections.emptyList());
+ when(hintValueContext.findHintDataSourceName()).thenReturn(Optional.of("ds_3"));
+ QueryContext queryContext = new QueryContext(commonSQLStatementContext, "", Collections.emptyList(), hintValueContext);
assertThrows(SQLHintDataSourceNotExistsException.class, () -> partialSQLRouteExecutor.route(connectionContext, queryContext, mock(ShardingSphereRuleMetaData.class), database));
}
diff --git a/infra/session/src/main/java/org/apache/shardingsphere/infra/session/query/QueryContext.java b/infra/session/src/main/java/org/apache/shardingsphere/infra/session/query/QueryContext.java
index 579ff63949c..c80311c50c8 100644
--- a/infra/session/src/main/java/org/apache/shardingsphere/infra/session/query/QueryContext.java
+++ b/infra/session/src/main/java/org/apache/shardingsphere/infra/session/query/QueryContext.java
@@ -22,6 +22,8 @@ import lombok.Getter;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.type.TableAvailable;
import org.apache.shardingsphere.infra.hint.HintValueContext;
+import org.apache.shardingsphere.infra.hint.SQLHintUtils;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
import java.util.List;
import java.util.Optional;
@@ -58,7 +60,9 @@ public final class QueryContext {
this.sql = sql;
parameters = params;
databaseName = sqlStatementContext instanceof TableAvailable ? ((TableAvailable) sqlStatementContext).getTablesContext().getDatabaseName().orElse(null) : null;
- this.hintValueContext = hintValueContext;
+ this.hintValueContext = sqlStatementContext.getSqlStatement() instanceof AbstractSQLStatement && !((AbstractSQLStatement) sqlStatementContext.getSqlStatement()).getCommentSegments().isEmpty()
+ ? SQLHintUtils.extractHint(((AbstractSQLStatement) sqlStatementContext.getSqlStatement()).getCommentSegments().iterator().next().getText()).orElse(hintValueContext)
+ : hintValueContext;
this.useCache = useCache;
}
diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java
index 4e91808158a..dfba1e03667 100644
--- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java
+++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java
@@ -538,7 +538,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
ShardingSphereRuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName());
- SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null);
+ SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
ExecutionContext result = kernelProcessor.generateExecutionContext(
queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
@@ -556,7 +556,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState
((ParameterAware) sqlStatementContext).setUpParameters(params);
}
SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
- HintValueContext hintValueContext = sqlParserRule.isSqlCommentParseEnabled() ? new HintValueContext() : SQLHintUtils.extractHint(sql);
+ HintValueContext hintValueContext = sqlParserRule.isSqlCommentParseEnabled() ? new HintValueContext() : SQLHintUtils.extractHint(sql).orElseGet(HintValueContext::new);
return new QueryContext(sqlStatementContext, sql, params, hintValueContext, true);
}
diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java
index 9e90c986604..3fb6ff29f5b 100644
--- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java
+++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java
@@ -498,7 +498,7 @@ public final class ShardingSphereStatement extends AbstractStatementAdapter {
SQLStatement sqlStatement = sqlParserRule.getSQLParserEngine(
DatabaseTypeEngine.getTrunkDatabaseTypeName(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType())).parse(sql, false);
SQLStatementContext sqlStatementContext = SQLStatementContextFactory.newInstance(metaDataContexts.getMetaData(), sqlStatement, connection.getDatabaseName());
- HintValueContext hintValueContext = sqlParserRule.isSqlCommentParseEnabled() ? new HintValueContext() : SQLHintUtils.extractHint(originSQL);
+ HintValueContext hintValueContext = sqlParserRule.isSqlCommentParseEnabled() ? new HintValueContext() : SQLHintUtils.extractHint(originSQL).orElseGet(HintValueContext::new);
return new QueryContext(sqlStatementContext, sql, Collections.emptyList(), hintValueContext);
}
@@ -506,7 +506,7 @@ public final class ShardingSphereStatement extends AbstractStatementAdapter {
clearStatements();
ShardingSphereRuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName());
- SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null);
+ SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
return kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(),
connection.getDatabaseConnectionManager().getConnectionContext());
}
diff --git a/kernel/traffic/core/src/main/java/org/apache/shardingsphere/traffic/rule/TrafficRule.java b/kernel/traffic/core/src/main/java/org/apache/shardingsphere/traffic/rule/TrafficRule.java
index d3eb1ffa54b..8eeef3bd77d 100644
--- a/kernel/traffic/core/src/main/java/org/apache/shardingsphere/traffic/rule/TrafficRule.java
+++ b/kernel/traffic/core/src/main/java/org/apache/shardingsphere/traffic/rule/TrafficRule.java
@@ -19,11 +19,10 @@ package org.apache.shardingsphere.traffic.rule;
import com.google.common.base.Preconditions;
import lombok.Getter;
-import org.apache.shardingsphere.infra.session.query.QueryContext;
-import org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.config.algorithm.AlgorithmConfiguration;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.rule.identifier.scope.GlobalRule;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.util.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.traffic.api.config.TrafficRuleConfiguration;
@@ -143,16 +142,12 @@ public final class TrafficRule implements GlobalRule {
return result;
}
- @SuppressWarnings("rawtypes")
private boolean match(final TrafficAlgorithm trafficAlgorithm, final QueryContext queryContext, final boolean inTransaction) {
if (trafficAlgorithm instanceof TransactionTrafficAlgorithm) {
return matchTransactionTraffic((TransactionTrafficAlgorithm) trafficAlgorithm, inTransaction);
}
if (trafficAlgorithm instanceof HintTrafficAlgorithm) {
- HintValueContext hintValueContext = queryContext.getSqlStatementContext() instanceof CommonSQLStatementContext
- ? ((CommonSQLStatementContext) queryContext.getSqlStatementContext()).getSqlHintExtractor().getHintValueContext()
- : new HintValueContext();
- return matchHintTraffic((HintTrafficAlgorithm) trafficAlgorithm, hintValueContext);
+ return matchHintTraffic((HintTrafficAlgorithm) trafficAlgorithm, queryContext.getHintValueContext());
}
if (trafficAlgorithm instanceof SegmentTrafficAlgorithm) {
SQLStatement sqlStatement = queryContext.getSqlStatementContext().getSqlStatement();
diff --git a/kernel/traffic/core/src/test/java/org/apache/shardingsphere/traffic/rule/TrafficRuleTest.java b/kernel/traffic/core/src/test/java/org/apache/shardingsphere/traffic/rule/TrafficRuleTest.java
index 66e2a2c4218..94f42dc369c 100644
--- a/kernel/traffic/core/src/test/java/org/apache/shardingsphere/traffic/rule/TrafficRuleTest.java
+++ b/kernel/traffic/core/src/test/java/org/apache/shardingsphere/traffic/rule/TrafficRuleTest.java
@@ -88,14 +88,12 @@ class TrafficRuleTest {
}
private QueryContext createQueryContext(final boolean includeComments) {
- QueryContext result = mock(QueryContext.class);
MySQLSelectStatement sqlStatement = mock(MySQLSelectStatement.class);
when(sqlStatement.getCommentSegments()).thenReturn(includeComments ? Collections.singleton(new CommentSegment("/* SHARDINGSPHERE_HINT: USE_TRAFFIC=true */", 0, 0)) : Collections.emptyList());
when(sqlStatement.getProjections()).thenReturn(new ProjectionsSegment(0, 0));
SQLStatementContext statementContext =
new SelectStatementContext(createShardingSphereMetaData(mockDatabase()), Collections.emptyList(), sqlStatement, DefaultDatabase.LOGIC_NAME);
- when(result.getSqlStatementContext()).thenReturn(statementContext);
- return result;
+ return new QueryContext(statementContext, "", Collections.emptyList());
}
private ShardingSphereDatabase mockDatabase() {
diff --git a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/ProxyBackendHandlerFactory.java b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/ProxyBackendHandlerFactory.java
index 756343bf2e1..7b74a4744c6 100644
--- a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/ProxyBackendHandlerFactory.java
+++ b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/ProxyBackendHandlerFactory.java
@@ -26,13 +26,13 @@ import org.apache.shardingsphere.distsql.parser.statement.DistSQLStatement;
import org.apache.shardingsphere.distsql.parser.statement.ral.QueryableRALStatement;
import org.apache.shardingsphere.distsql.parser.statement.rql.RQLStatement;
import org.apache.shardingsphere.distsql.parser.statement.rul.RULStatement;
-import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.binder.SQLStatementContextFactory;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
import org.apache.shardingsphere.infra.executor.audit.SQLAuditEngine;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.state.cluster.ClusterState;
import org.apache.shardingsphere.infra.util.exception.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.util.exception.external.sql.type.generic.UnsupportedSQLOperationException;
@@ -162,8 +162,8 @@ public final class ProxyBackendHandlerFactory {
AuthorityRule authorityRule = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(AuthorityRule.class);
ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabase(databaseName);
new AuthorityChecker(authorityRule, connectionSession.getGrantee()).checkPrivileges(databaseName, sqlStatementContext.getSqlStatement());
- SQLAuditEngine.audit(sqlStatementContext, queryContext.getParameters(),
- ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData(), database, connectionSession.getGrantee());
+ SQLAuditEngine.audit(sqlStatementContext, queryContext.getParameters(), ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData(),
+ database, connectionSession.getGrantee(), queryContext.getHintValueContext());
backendHandler = DatabaseAdminBackendHandlerFactory.newInstance(databaseType, sqlStatementContext, connectionSession);
return backendHandler.orElseGet(() -> DatabaseBackendHandlerFactory.newInstance(queryContext, connectionSession, preferPreparedStatement));
}
diff --git a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandler.java b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandler.java
index 8f810222a30..caf0aaa2601 100644
--- a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandler.java
+++ b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandler.java
@@ -17,7 +17,6 @@
package org.apache.shardingsphere.proxy.frontend.mysql.command.query.text.query;
-import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.binder.SQLStatementContextFactory;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
@@ -42,6 +41,7 @@ import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.parser.SQLParserEngine;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.util.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.parser.rule.SQLParserRule;
@@ -130,7 +130,7 @@ public final class MySQLMultiStatementsHandler implements ProxyBackendHandler {
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
ShardingSphereRuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName());
- SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null);
+ SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
return kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connectionSession.getConnectionContext());
}
diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutor.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutor.java
index f63bbd14286..d103937b18b 100644
--- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutor.java
+++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutor.java
@@ -18,7 +18,6 @@
package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLTypeUnspecifiedSQLParameter;
-import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.binder.SQLStatementContextFactory;
import org.apache.shardingsphere.infra.binder.aware.ParameterAware;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
@@ -42,6 +41,7 @@ import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.Statemen
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.proxy.backend.connector.jdbc.statement.JDBCBackendStatement;
import org.apache.shardingsphere.proxy.backend.context.BackendExecutorContext;
@@ -129,7 +129,7 @@ public final class PostgreSQLBatchedStatementsExecutor {
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
ShardingSphereRuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName());
- SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null);
+ SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
return kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connectionSession.getConnectionContext());
}
diff --git a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java
index 852323994d0..9f89bc80fde 100644
--- a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java
+++ b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java
@@ -145,7 +145,8 @@ public abstract class SQLRewriterIT {
SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(database, globalRuleMetaData, props);
ConnectionContext connectionContext = mock(ConnectionContext.class);
when(connectionContext.getCursorContext()).thenReturn(new CursorConnectionContext());
- SQLRewriteResult sqlRewriteResult = sqlRewriteEntry.rewrite(testParams.getInputSQL(), testParams.getInputParameters(), sqlStatementContext, routeContext, connectionContext);
+ SQLRewriteResult sqlRewriteResult = sqlRewriteEntry.rewrite(testParams.getInputSQL(), testParams.getInputParameters(), sqlStatementContext, routeContext, connectionContext,
+ queryContext.getHintValueContext());
return sqlRewriteResult instanceof GenericSQLRewriteResult
? Collections.singleton(((GenericSQLRewriteResult) sqlRewriteResult).getSqlRewriteUnit())
: (((RouteSQLRewriteResult) sqlRewriteResult).getSqlRewriteUnits()).values();