You are viewing a plain-text version of this content. The canonical version is linked from the original archive page; that hyperlink is not preserved in this plain-text copy.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2020/07/27 03:19:53 UTC
[shardingsphere] branch master updated: support mysql insert select statement route and rewrite (#6438)
This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new e3ed615 support mysql insert select statement route and rewrite (#6438)
e3ed615 is described below
commit e3ed6150f23a0264606d2e68149ece0c88e957a8
Author: DuanZhengqiang <st...@gmail.com>
AuthorDate: Sun Jul 26 22:19:32 2020 -0500
support mysql insert select statement route and rewrite (#6438)
---
.../shardingsphere/sharding/rule/ShardingRule.java | 16 +++
.../src/test/resources/sharding/insert.xml | 152 +++++++++++++++++++++
.../route/engine/ShardingRouteDecorator.java | 18 ++-
.../InsertClauseShardingConditionEngine.java | 15 +-
.../standard/ShardingStandardRoutingEngine.java | 6 +-
.../validator/ShardingStatementValidator.java | 6 +-
.../impl/ShardingInsertStatementValidator.java | 33 ++++-
.../impl/ShardingUpdateStatementValidator.java | 7 +-
.../impl/ShardingInsertStatementValidatorTest.java | 63 ++++++++-
.../impl/ShardingUpdateStatementValidatorTest.java | 16 ++-
.../segment/insert/values/InsertSelectContext.java | 55 ++++++++
.../parser/binder/segment/table/TablesContext.java | 2 +
.../statement/dml/InsertStatementContext.java | 33 ++++-
.../parser/sql/statement/dml/InsertStatement.java | 3 +
14 files changed, 394 insertions(+), 31 deletions(-)
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-common/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-common/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
index 5131a9c..5bd30b1 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-common/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-common/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
@@ -312,6 +312,22 @@ public final class ShardingRule implements DataNodeRoutedRule {
}
/**
+ * Judge is generate key column or not.
+ *
+ * @param columnName column name
+ * @param tableName table name
+ * @return is generate key column or not
+ */
+ public boolean isGenerateKeyColumn(final String columnName, final String tableName) {
+ return tableRules.stream().anyMatch(each -> each.getLogicTable().equalsIgnoreCase(tableName) && isGenerateKeyColumn(each, columnName));
+ }
+
+ private boolean isGenerateKeyColumn(final TableRule tableRule, final String columnName) {
+ Optional<String> generateKeyColumn = tableRule.getGenerateKeyColumn();
+ return generateKeyColumn.isPresent() && generateKeyColumn.get().equalsIgnoreCase(columnName);
+ }
+
+ /**
* Find column name of generated key.
*
* @param logicTableName logic table name
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/insert.xml b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/insert.xml
index 948712f..e66a3d4 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/insert.xml
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/insert.xml
@@ -137,6 +137,158 @@
<output sql="INSERT INTO t_account_1 VALUES (101, 2000, 'OK') ON DUPLICATE KEY UPDATE status = ?" parameters="OK_UPDATE" />
</rewrite-assertion>
+ <rewrite-assertion id="insert_select_with_all_columns_with_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account WHERE account_id = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="100, 20" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_0 WHERE account_id = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="100, 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_all_columns_without_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_0 WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ <output sql="INSERT INTO t_account_1 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_1 WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_all_columns_with_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account WHERE account_id = 100 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_0 WHERE account_id = 100 ON DUPLICATE KEY UPDATE amount = 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_all_columns_without_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_0 WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_1 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_1 WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_without_columns_with_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account WHERE account_id = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="100, 20" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_0 WHERE account_id = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="100, 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_without_columns_without_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_0 WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ <output sql="INSERT INTO t_account_1 SELECT account_id, amount, status FROM t_account_1 WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_without_columns_with_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account WHERE account_id = 100 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_0 WHERE account_id = 100 ON DUPLICATE KEY UPDATE amount = 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_without_columns_without_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_0 WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_1 SELECT account_id, amount, status FROM t_account_1 WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_binding_table_with_all_columns_with_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail WHERE account_id = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="100, 20" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail_0 WHERE account_id = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="100, 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_binding_table_with_all_columns_without_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail_0 WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ <output sql="INSERT INTO t_account_1 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail_1 WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_binding_table_with_all_columns_with_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail WHERE account_id = 100 ON DUPLICATE KEY UPDATE amount = VALUES(amount)" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail_0 WHERE account_id = 100 ON DUPLICATE KEY UPDATE amount = VALUES(amount)" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_binding_table_with_all_columns_without_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = VALUES(amount)" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail_0 WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = VALUES(amount)" />
+ <output sql="INSERT INTO t_account_1 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_detail_1 WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = VALUES(amount)" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_binding_table_without_columns_with_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account_detail WHERE account_id = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="100, 20" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_detail_0 WHERE account_id = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="100, 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_binding_table_without_columns_without_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account_detail WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_detail_0 WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ <output sql="INSERT INTO t_account_1 SELECT account_id, amount, status FROM t_account_detail_1 WHERE amount = ? ON DUPLICATE KEY UPDATE amount = ?" parameters="1000, 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_binding_table_without_columns_with_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account_detail WHERE account_id = 100 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_detail_0 WHERE account_id = 100 ON DUPLICATE KEY UPDATE amount = 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_binding_table_without_columns_without_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account_detail WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_detail_0 WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ <output sql="INSERT INTO t_account_1 SELECT account_id, amount, status FROM t_account_detail_1 WHERE amount = 1000 ON DUPLICATE KEY UPDATE amount = 20" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_subquery_with_all_columns_with_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT t.account_id, t.amount, t.status FROM (SELECT account_id, amount, status FROM t_account WHERE account_id = ?) t WHERE t.account_id = ?" parameters="100, 100" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT t.account_id, t.amount, t.status FROM (SELECT account_id, amount, status FROM t_account_0 WHERE account_id = ?) t WHERE t.account_id = ?" parameters="100, 100" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_subquery_with_all_columns_with_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT t.account_id, t.amount, t.status FROM (SELECT account_id, amount, status FROM t_account WHERE account_id = 100) t WHERE t.account_id = 100" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT t.account_id, t.amount, t.status FROM (SELECT account_id, amount, status FROM t_account_0 WHERE account_id = 100) t WHERE t.account_id = 100" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_subquery_without_columns_with_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT t.account_id, t.amount, t.status FROM (SELECT account_id, amount, status FROM t_account WHERE account_id = ?) t WHERE t.account_id = ?" parameters="100, 100" />
+ <output sql="INSERT INTO t_account_0 SELECT t.account_id, t.amount, t.status FROM (SELECT account_id, amount, status FROM t_account_0 WHERE account_id = ?) t WHERE t.account_id = ?" parameters="100, 100" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_subquery_without_columns_with_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT t.account_id, t.amount, t.status FROM (SELECT account_id, amount, status FROM t_account WHERE account_id = 100) t WHERE t.account_id = 100" />
+ <output sql="INSERT INTO t_account_0 SELECT t.account_id, t.amount, t.status FROM (SELECT account_id, amount, status FROM t_account_0 WHERE account_id = 100) t WHERE t.account_id = 100" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_pagination_with_all_columns_with_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account WHERE account_id = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_0 WHERE account_id = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_pagination_with_all_columns_without_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account WHERE amount = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_0 WHERE amount = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ <output sql="INSERT INTO t_account_1 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_1 WHERE amount = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_pagination_with_all_columns_with_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account WHERE account_id = 100 LIMIT 1, 2" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_0 WHERE account_id = 100 LIMIT 1, 2" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_pagination_with_all_columns_without_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account (account_id, amount, status) SELECT account_id, amount, status FROM t_account WHERE amount = 100 LIMIT 1, 2" />
+ <output sql="INSERT INTO t_account_0 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_0 WHERE amount = 100 LIMIT 1, 2" />
+ <output sql="INSERT INTO t_account_1 (account_id, amount, status) SELECT account_id, amount, status FROM t_account_1 WHERE amount = 100 LIMIT 1, 2" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_pagination_without_columns_with_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account WHERE account_id = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_0 WHERE account_id = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_pagination_without_columns_without_sharding_column_for_parameters" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account WHERE amount = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_0 WHERE amount = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ <output sql="INSERT INTO t_account_1 SELECT account_id, amount, status FROM t_account_1 WHERE amount = ? LIMIT ?, ?" parameters="100, 1, 2" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_pagination_without_columns_with_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account WHERE account_id = 100 LIMIT 1, 2" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_0 WHERE account_id = 100 LIMIT 1, 2" />
+ </rewrite-assertion>
+
+ <rewrite-assertion id="insert_select_with_pagination_without_columns_without_sharding_column_for_literals" db-type="MySQL">
+ <input sql="INSERT INTO t_account SELECT account_id, amount, status FROM t_account WHERE amount = 100 LIMIT 1, 2" />
+ <output sql="INSERT INTO t_account_0 SELECT account_id, amount, status FROM t_account_0 WHERE amount = 100 LIMIT 1, 2" />
+ <output sql="INSERT INTO t_account_1 SELECT account_id, amount, status FROM t_account_1 WHERE amount = 100 LIMIT 1, 2" />
+ </rewrite-assertion>
+
<rewrite-assertion id="replace_values_with_columns_with_id_for_parameters" db-type="MySQL">
<input sql="REPLACE INTO t_account (account_id, amount, status) VALUES (?, ?, ?)" parameters="100, 1000, OK" />
<output sql="REPLACE INTO t_account_0 (account_id, amount, status) VALUES (?, ?, ?)" parameters="100, 1000, OK" />
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/ShardingRouteDecorator.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/ShardingRouteDecorator.java
index a4ceceb..8cef1b1 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/ShardingRouteDecorator.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/ShardingRouteDecorator.java
@@ -42,6 +42,7 @@ import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaDat
import org.apache.shardingsphere.sql.parser.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.sql.parser.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.sql.parser.binder.statement.dml.SelectStatementContext;
+import org.apache.shardingsphere.sql.parser.sql.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.DMLStatement;
import java.util.Collections;
@@ -58,11 +59,12 @@ public final class ShardingRouteDecorator implements RouteDecorator<ShardingRule
public RouteContext decorate(final RouteContext routeContext, final ShardingSphereMetaData metaData, final ShardingRule shardingRule, final ConfigurationProperties props) {
SQLStatementContext sqlStatementContext = routeContext.getSqlStatementContext();
List<Object> parameters = routeContext.getParameters();
+ SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
ShardingStatementValidatorFactory.newInstance(
- sqlStatementContext.getSqlStatement()).ifPresent(validator -> validator.validate(shardingRule, sqlStatementContext.getSqlStatement(), parameters));
+ sqlStatement).ifPresent(validator -> validator.validate(shardingRule, sqlStatement, sqlStatementContext.getTablesContext(), parameters));
ShardingConditions shardingConditions = getShardingConditions(parameters, sqlStatementContext, metaData.getSchema().getConfiguredSchemaMetaData(), shardingRule);
boolean needMergeShardingValues = isNeedMergeShardingValues(sqlStatementContext, shardingRule);
- if (sqlStatementContext.getSqlStatement() instanceof DMLStatement && needMergeShardingValues) {
+ if (sqlStatement instanceof DMLStatement && needMergeShardingValues) {
checkSubqueryShardingValues(sqlStatementContext, shardingRule, shardingConditions);
mergeShardingConditions(shardingConditions);
}
@@ -71,11 +73,11 @@ public final class ShardingRouteDecorator implements RouteDecorator<ShardingRule
return new RouteContext(sqlStatementContext, parameters, routeResult);
}
- private ShardingConditions getShardingConditions(final List<Object> parameters,
- final SQLStatementContext sqlStatementContext, final SchemaMetaData schemaMetaData, final ShardingRule shardingRule) {
+ private ShardingConditions getShardingConditions(final List<Object> parameters, final SQLStatementContext sqlStatementContext,
+ final SchemaMetaData schemaMetaData, final ShardingRule shardingRule) {
if (sqlStatementContext.getSqlStatement() instanceof DMLStatement) {
if (sqlStatementContext instanceof InsertStatementContext) {
- return new ShardingConditions(new InsertClauseShardingConditionEngine(shardingRule).createShardingConditions((InsertStatementContext) sqlStatementContext, parameters));
+ return new ShardingConditions(new InsertClauseShardingConditionEngine(shardingRule, schemaMetaData).createShardingConditions((InsertStatementContext) sqlStatementContext, parameters));
}
return new ShardingConditions(new WhereClauseShardingConditionEngine(shardingRule, schemaMetaData).createShardingConditions(sqlStatementContext, parameters));
}
@@ -83,8 +85,10 @@ public final class ShardingRouteDecorator implements RouteDecorator<ShardingRule
}
private boolean isNeedMergeShardingValues(final SQLStatementContext sqlStatementContext, final ShardingRule shardingRule) {
- return sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsSubquery()
- && !shardingRule.getShardingLogicTableNames(sqlStatementContext.getTablesContext().getTableNames()).isEmpty();
+ boolean selectContainsSubquery = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsSubquery();
+ boolean insertSelectContainsSubquery = sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
+ && ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext().isContainsSubquery();
+ return (selectContainsSubquery || insertSelectContainsSubquery) && !shardingRule.getShardingLogicTableNames(sqlStatementContext.getTablesContext().getTableNames()).isEmpty();
}
private void checkSubqueryShardingValues(final SQLStatementContext sqlStatementContext, final ShardingRule shardingRule, final ShardingConditions shardingConditions) {
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/InsertClauseShardingConditionEngine.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/InsertClauseShardingConditionEngine.java
index 579747d..9469af4 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/InsertClauseShardingConditionEngine.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/InsertClauseShardingConditionEngine.java
@@ -19,19 +19,21 @@ package org.apache.shardingsphere.sharding.route.engine.condition.engine;
import com.google.common.base.Preconditions;
import lombok.RequiredArgsConstructor;
-import org.apache.shardingsphere.sharding.rule.ShardingRule;
-import org.apache.shardingsphere.sharding.strategy.value.ListRouteValue;
+import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.sharding.route.engine.condition.ExpressionConditionUtils;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
import org.apache.shardingsphere.sharding.route.spi.SPITimeService;
+import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.sharding.strategy.value.ListRouteValue;
+import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaData;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.sql.parser.binder.statement.dml.InsertStatementContext;
+import org.apache.shardingsphere.sql.parser.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.SimpleExpressionSegment;
-import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import java.util.Collection;
import java.util.Collections;
@@ -50,6 +52,8 @@ public final class InsertClauseShardingConditionEngine {
private final ShardingRule shardingRule;
+ private final SchemaMetaData schemaMetaData;
+
/**
* Create sharding conditions.
*
@@ -64,6 +68,11 @@ public final class InsertClauseShardingConditionEngine {
for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
result.add(createShardingCondition(tableName, columnNames.iterator(), each, parameters));
}
+ if (null != insertStatementContext.getInsertSelectContext()) {
+ SelectStatementContext selectStatementContext = insertStatementContext.getInsertSelectContext().getSelectStatementContext();
+ List<ShardingCondition> shardingConditions = new WhereClauseShardingConditionEngine(shardingRule, schemaMetaData).createShardingConditions(selectStatementContext, parameters);
+ result.addAll(shardingConditions);
+ }
Optional<GeneratedKeyContext> generatedKey = insertStatementContext.getGeneratedKeyContext();
if (generatedKey.isPresent() && generatedKey.get().isGenerated()) {
generatedKey.get().getGeneratedValues().addAll(getGeneratedKeys(tableName, insertStatementContext.getSqlStatement().getValueListCount()));
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/standard/ShardingStandardRoutingEngine.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/standard/ShardingStandardRoutingEngine.java
index fbf7bb9..135f761 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/standard/ShardingStandardRoutingEngine.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/standard/ShardingStandardRoutingEngine.java
@@ -68,7 +68,7 @@ public final class ShardingStandardRoutingEngine implements ShardingRouteEngine
@Override
public RouteResult route(final ShardingRule shardingRule) {
- if (isDMLForModify(sqlStatementContext) && 1 != ((TableAvailable) sqlStatementContext).getAllTables().size()) {
+ if (isDMLForModify(sqlStatementContext) && !containsInsertSelect(sqlStatementContext) && 1 != ((TableAvailable) sqlStatementContext).getAllTables().size()) {
throw new ShardingSphereException("Cannot support Multiple-Table for '%s'.", sqlStatementContext.getSqlStatement());
}
return generateRouteResult(getDataNodes(shardingRule, shardingRule.getTableRule(logicTableName)));
@@ -78,6 +78,10 @@ public final class ShardingStandardRoutingEngine implements ShardingRouteEngine
return sqlStatementContext instanceof InsertStatementContext || sqlStatementContext instanceof UpdateStatementContext || sqlStatementContext instanceof DeleteStatementContext;
}
+ private boolean containsInsertSelect(final SQLStatementContext sqlStatementContext) {
+ return sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext();
+ }
+
private RouteResult generateRouteResult(final Collection<DataNode> routedDataNodes) {
RouteResult result = new RouteResult();
result.getOriginalDataNodes().addAll(originalDataNodes);
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/ShardingStatementValidator.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/ShardingStatementValidator.java
index 62dd9d6..e4cd5a6 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/ShardingStatementValidator.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/ShardingStatementValidator.java
@@ -17,8 +17,9 @@
package org.apache.shardingsphere.sharding.route.engine.validator;
-import org.apache.shardingsphere.sql.parser.sql.statement.SQLStatement;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
+import org.apache.shardingsphere.sql.parser.sql.statement.SQLStatement;
import java.util.List;
@@ -34,7 +35,8 @@ public interface ShardingStatementValidator<T extends SQLStatement> {
*
* @param shardingRule sharding rule
* @param sqlStatement SQL statement
+ * @param tablesContext table context
* @param parameters SQL parameters
*/
- void validate(ShardingRule shardingRule, T sqlStatement, List<Object> parameters);
+ void validate(ShardingRule shardingRule, T sqlStatement, TablesContext tablesContext, List<Object> parameters);
}
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingInsertStatementValidator.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingInsertStatementValidator.java
index ef5cf4c..1712ecd 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingInsertStatementValidator.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingInsertStatementValidator.java
@@ -17,13 +17,17 @@
package org.apache.shardingsphere.sharding.route.engine.validator.impl;
-import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.sharding.route.engine.validator.ShardingStatementValidator;
+import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.OnDuplicateKeyColumnsSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
-import org.apache.shardingsphere.infra.exception.ShardingSphereException;
+import java.util.Collection;
import java.util.List;
import java.util.Optional;
@@ -33,11 +37,20 @@ import java.util.Optional;
public final class ShardingInsertStatementValidator implements ShardingStatementValidator<InsertStatement> {
@Override
- public void validate(final ShardingRule shardingRule, final InsertStatement sqlStatement, final List<Object> parameters) {
+ public void validate(final ShardingRule shardingRule, final InsertStatement sqlStatement, final TablesContext tablesContext, final List<Object> parameters) {
Optional<OnDuplicateKeyColumnsSegment> onDuplicateKeyColumnsSegment = sqlStatement.getOnDuplicateKeyColumns();
- if (onDuplicateKeyColumnsSegment.isPresent() && isUpdateShardingKey(shardingRule, onDuplicateKeyColumnsSegment.get(), sqlStatement.getTable().getTableName().getIdentifier().getValue())) {
+ String tableName = sqlStatement.getTable().getTableName().getIdentifier().getValue();
+ if (onDuplicateKeyColumnsSegment.isPresent() && isUpdateShardingKey(shardingRule, onDuplicateKeyColumnsSegment.get(), tableName)) {
throw new ShardingSphereException("INSERT INTO .... ON DUPLICATE KEY UPDATE can not support update for sharding column.");
}
+ Optional<SubquerySegment> insertSelectSegment = sqlStatement.getInsertSelect();
+ if (insertSelectSegment.isPresent() && isContainsKeyGenerateStrategy(shardingRule, tableName)
+ && !isContainsKeyGenerateColumn(shardingRule, sqlStatement.getColumns(), tableName)) {
+ throw new ShardingSphereException("INSERT INTO .... SELECT can not support applying keyGenerator to absent generateKeyColumn.");
+ }
+ if (insertSelectSegment.isPresent() && !isAllSameTables(tablesContext.getTableNames()) && !shardingRule.isAllBindingTables(tablesContext.getTableNames())) {
+ throw new ShardingSphereException("The table inserted and the table selected must be the same or bind tables.");
+ }
}
private boolean isUpdateShardingKey(final ShardingRule shardingRule, final OnDuplicateKeyColumnsSegment onDuplicateKeyColumnsSegment, final String tableName) {
@@ -48,4 +61,16 @@ public final class ShardingInsertStatementValidator implements ShardingStatement
}
return false;
}
+
+    // A table has a key-generate strategy exactly when the rule names a generate-key column for it.
+    private boolean isContainsKeyGenerateStrategy(final ShardingRule shardingRule, final String tableName) {
+        Optional<String> generateKeyColumnName = shardingRule.findGenerateKeyColumnName(tableName);
+        return generateKeyColumnName.isPresent();
+    }
+
+    // An empty insert column list means every column is targeted, so the generate-key column is included too.
+    private boolean isContainsKeyGenerateColumn(final ShardingRule shardingRule, final Collection<ColumnSegment> columns, final String tableName) {
+        if (columns.isEmpty()) {
+            return true;
+        }
+        for (ColumnSegment each : columns) {
+            if (shardingRule.isGenerateKeyColumn(each.getIdentifier().getValue(), tableName)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Returns true when every name refers to one single table; an empty collection yields false.
+    // Names are compared ignoring case so that e.g. "t_order" and "T_ORDER" count as the same table.
+    // NOTE(review): assumes ShardingRule resolves logic table names case-insensitively, as its table lookups appear to; confirm.
+    private boolean isAllSameTables(final Collection<String> tableNames) {
+        if (tableNames.isEmpty()) {
+            return false;
+        }
+        String firstTableName = tableNames.iterator().next();
+        return tableNames.stream().allMatch(firstTableName::equalsIgnoreCase);
+    }
}
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingUpdateStatementValidator.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingUpdateStatementValidator.java
index 645f1f8..1ed8580 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingUpdateStatementValidator.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/main/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingUpdateStatementValidator.java
@@ -17,8 +17,10 @@
package org.apache.shardingsphere.sharding.route.engine.validator.impl;
-import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.sharding.route.engine.validator.ShardingStatementValidator;
+import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
@@ -30,7 +32,6 @@ import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.Pred
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateInRightValue;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateRightValue;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.UpdateStatement;
-import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import java.util.Collection;
import java.util.List;
@@ -42,7 +43,7 @@ import java.util.Optional;
public final class ShardingUpdateStatementValidator implements ShardingStatementValidator<UpdateStatement> {
@Override
- public void validate(final ShardingRule shardingRule, final UpdateStatement sqlStatement, final List<Object> parameters) {
+ public void validate(final ShardingRule shardingRule, final UpdateStatement sqlStatement, final TablesContext tablesContext, final List<Object> parameters) {
String tableName = sqlStatement.getTables().iterator().next().getTableName().getIdentifier().getValue();
for (AssignmentSegment each : sqlStatement.getSetAssignment().getAssignments()) {
String shardingColumn = each.getColumn().getIdentifier().getValue();
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingInsertStatementValidatorTest.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingInsertStatementValidatorTest.java
index ce325da..db310c7 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingInsertStatementValidatorTest.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingInsertStatementValidatorTest.java
@@ -17,21 +17,29 @@
package org.apache.shardingsphere.sharding.route.engine.validator.impl;
+import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.OnDuplicateKeyColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
+import org.apache.shardingsphere.sql.parser.sql.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.value.identifier.IdentifierValue;
-import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
+import java.util.Collection;
import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
import static org.mockito.Mockito.when;
@@ -44,13 +52,45 @@ public final class ShardingInsertStatementValidatorTest {
@Test
public void assertValidateOnDuplicateKeyWithoutShardingKey() {
when(shardingRule.isShardingColumn("id", "user")).thenReturn(false);
- new ShardingInsertStatementValidator().validate(shardingRule, createInsertStatement(), Collections.emptyList());
+ new ShardingInsertStatementValidator().validate(shardingRule, createInsertStatement(), createSingleTablesContext(), Collections.emptyList());
}
@Test(expected = ShardingSphereException.class)
public void assertValidateOnDuplicateKeyWithShardingKey() {
when(shardingRule.isShardingColumn("id", "user")).thenReturn(true);
- new ShardingInsertStatementValidator().validate(shardingRule, createInsertStatement(), Collections.emptyList());
+ new ShardingInsertStatementValidator().validate(shardingRule, createInsertStatement(), createSingleTablesContext(), Collections.emptyList());
+ }
+
+ @Test(expected = ShardingSphereException.class)
+ public void assertValidateInsertSelectWithoutKeyGenerateColumn() {
+ when(shardingRule.findGenerateKeyColumnName("user")).thenReturn(Optional.of("id"));
+ when(shardingRule.isGenerateKeyColumn("id", "user")).thenReturn(false);
+ new ShardingInsertStatementValidator().validate(shardingRule, createInsertStatement(), createSingleTablesContext(), Collections.emptyList());
+ }
+
+ @Test
+ public void assertValidateInsertSelectWithKeyGenerateColumn() {
+ when(shardingRule.findGenerateKeyColumnName("user")).thenReturn(Optional.of("id"));
+ when(shardingRule.isGenerateKeyColumn("id", "user")).thenReturn(true);
+ new ShardingInsertStatementValidator().validate(shardingRule, createInsertStatement(), createSingleTablesContext(), Collections.emptyList());
+ }
+
+    @Test(expected = ShardingSphereException.class)
+    public void assertValidateInsertSelectWithoutBindingTables() {
+        // Two different tables that are not bound together -> rejected.
+        TablesContext tablesContext = createMultiTablesContext();
+        when(shardingRule.findGenerateKeyColumnName("user")).thenReturn(Optional.of("id"));
+        when(shardingRule.isGenerateKeyColumn("id", "user")).thenReturn(true);
+        when(shardingRule.isAllBindingTables(tablesContext.getTableNames())).thenReturn(false);
+        InsertStatement insertStatement = createInsertStatement();
+        new ShardingInsertStatementValidator().validate(shardingRule, insertStatement, tablesContext, Collections.emptyList());
+    }
+
+    @Test
+    public void assertValidateInsertSelectWithBindingTables() {
+        // Two different tables that belong to one binding group -> accepted.
+        TablesContext tablesContext = createMultiTablesContext();
+        when(shardingRule.findGenerateKeyColumnName("user")).thenReturn(Optional.of("id"));
+        when(shardingRule.isGenerateKeyColumn("id", "user")).thenReturn(true);
+        when(shardingRule.isAllBindingTables(tablesContext.getTableNames())).thenReturn(true);
+        InsertStatement insertStatement = createInsertStatement();
+        new ShardingInsertStatementValidator().validate(shardingRule, insertStatement, tablesContext, Collections.emptyList());
     }
private InsertStatement createInsertStatement() {
@@ -59,6 +99,23 @@ public final class ShardingInsertStatementValidatorTest {
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("id"));
AssignmentSegment assignmentSegment = new AssignmentSegment(0, 0, columnSegment, new ParameterMarkerExpressionSegment(0, 0, 1));
result.setOnDuplicateKeyColumns(new OnDuplicateKeyColumnsSegment(0, 0, Collections.singletonList(assignmentSegment)));
+ Collection<ColumnSegment> columns = new LinkedList<>();
+ columns.add(columnSegment);
+ result.setInsertColumns(new InsertColumnsSegment(0, 0, columns));
+ result.setInsertSelect(new SubquerySegment(0, 0, new SelectStatement()));
return result;
}
+
+    // Tables context holding only the "user" table.
+    private TablesContext createSingleTablesContext() {
+        SimpleTableSegment userTable = new SimpleTableSegment(0, 0, new IdentifierValue("user"));
+        List<SimpleTableSegment> tables = new LinkedList<>();
+        tables.add(userTable);
+        return new TablesContext(tables);
+    }
+
+    // Tables context holding "user" followed by "order".
+    private TablesContext createMultiTablesContext() {
+        List<SimpleTableSegment> tables = new LinkedList<>();
+        for (String each : new String[] {"user", "order"}) {
+            tables.add(new SimpleTableSegment(0, 0, new IdentifierValue(each)));
+        }
+        return new TablesContext(tables);
+    }
}
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingUpdateStatementValidatorTest.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingUpdateStatementValidatorTest.java
index f9e068f..078b6df 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingUpdateStatementValidatorTest.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-route/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/impl/ShardingUpdateStatementValidatorTest.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.sharding.route.engine.validator.impl;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
@@ -52,34 +53,34 @@ public final class ShardingUpdateStatementValidatorTest {
@Test
public void assertValidateUpdateWithoutShardingKey() {
when(shardingRule.isShardingColumn("id", "user")).thenReturn(false);
- new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatement(), Collections.emptyList());
+ new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatement(), createTablesContext(), Collections.emptyList());
}
@Test(expected = ShardingSphereException.class)
public void assertValidateUpdateWithShardingKey() {
when(shardingRule.isShardingColumn("id", "user")).thenReturn(true);
- new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatement(), Collections.emptyList());
+ new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatement(), createTablesContext(), Collections.emptyList());
}
@Test
public void assertValidateUpdateWithoutShardingKeyAndParameters() {
when(shardingRule.isShardingColumn("id", "user")).thenReturn(false);
List<Object> parameters = Arrays.asList(1, 1);
- new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatement(), parameters);
+ new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatement(), createTablesContext(), parameters);
}
@Test
public void assertValidateUpdateWithShardingKeyAndShardingParameterEquals() {
when(shardingRule.isShardingColumn("id", "user")).thenReturn(true);
List<Object> parameters = Arrays.asList(1, 1);
- new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatementAndParameters(1), parameters);
+ new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatementAndParameters(1), createTablesContext(), parameters);
}
@Test(expected = ShardingSphereException.class)
public void assertValidateUpdateWithShardingKeyAndShardingParameterNotEquals() {
when(shardingRule.isShardingColumn("id", "user")).thenReturn(true);
List<Object> parameters = Arrays.asList(1, 1);
- new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatementAndParameters(2), parameters);
+ new ShardingUpdateStatementValidator().validate(shardingRule, createUpdateStatementAndParameters(2), createTablesContext(), parameters);
}
private UpdateStatement createUpdateStatement() {
@@ -104,4 +105,9 @@ public final class ShardingUpdateStatementValidatorTest {
result.setWhere(where);
return result;
}
+
+    // Tables context for the single "user" table updated by the statements under test.
+    private TablesContext createTablesContext() {
+        return new TablesContext(new SimpleTableSegment(0, 0, new IdentifierValue("user")));
+    }
}
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/InsertSelectContext.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/InsertSelectContext.java
new file mode 100644
index 0000000..79b74f4
--- /dev/null
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/InsertSelectContext.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.sql.parser.binder.segment.insert.values;
+
+import lombok.Getter;
+import lombok.ToString;
+import org.apache.shardingsphere.sql.parser.binder.statement.dml.SelectStatementContext;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Insert select context.
+ *
+ * <p>Holds the bound {@code SELECT} part of an {@code INSERT ... SELECT} statement together with its own copy of
+ * the parameters that belong to that {@code SELECT}.</p>
+ */
+@Getter
+@ToString
+public final class InsertSelectContext {
+    
+    private final int parametersCount;
+    
+    private final List<Object> parameters;
+    
+    private final SelectStatementContext selectStatementContext;
+    
+    /**
+     * Create insert select context.
+     *
+     * @param selectStatementContext context of the {@code SELECT} part
+     * @param parameters all parameters of the enclosing {@code INSERT} statement
+     * @param parametersOffset index of the first parameter that belongs to the {@code SELECT} part
+     */
+    public InsertSelectContext(final SelectStatementContext selectStatementContext, final List<Object> parameters, final int parametersOffset) {
+        // parametersCount must be assigned before extractParameters reads it.
+        parametersCount = selectStatementContext.getSqlStatement().getParameterCount();
+        this.selectStatementContext = selectStatementContext;
+        this.parameters = extractParameters(parameters, parametersOffset);
+    }
+    
+    // Named apart from the Lombok-generated getParameters() so the two are not confused as overloads.
+    private List<Object> extractParameters(final List<Object> parameters, final int parametersOffset) {
+        if (0 == parametersCount) {
+            return Collections.emptyList();
+        }
+        // Copy the slice: subList is only a view backed by the caller's list.
+        return new ArrayList<>(parameters.subList(parametersOffset, parametersOffset + parametersCount));
+    }
+}
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/table/TablesContext.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/table/TablesContext.java
index 2054fb9..0990a56 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/table/TablesContext.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/table/TablesContext.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.sql.parser.binder.segment.table;
+import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaData;
@@ -34,6 +35,7 @@ import java.util.Optional;
*/
@RequiredArgsConstructor
@ToString
+@Getter
public final class TablesContext {
private final Collection<SimpleTableSegment> tables;
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/InsertStatementContext.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/InsertStatementContext.java
index b10bae5..fb80c27 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/InsertStatementContext.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/InsertStatementContext.java
@@ -22,6 +22,7 @@ import lombok.ToString;
import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaData;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.keygen.engine.GeneratedKeyContextEngine;
+import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.InsertSelectContext;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext;
import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
@@ -29,12 +30,12 @@ import org.apache.shardingsphere.sql.parser.binder.statement.CommonSQLStatementC
import org.apache.shardingsphere.sql.parser.binder.type.TableAvailable;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
import java.util.ArrayList;
import java.util.Collection;
-import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
@@ -54,20 +55,32 @@ public final class InsertStatementContext extends CommonSQLStatementContext<Inse
private final List<InsertValueContext> insertValueContexts;
+ private final InsertSelectContext insertSelectContext;
+
private final OnDuplicateUpdateContext onDuplicateKeyUpdateValueContext;
private final GeneratedKeyContext generatedKeyContext;
public InsertStatementContext(final SchemaMetaData schemaMetaData, final List<Object> parameters, final InsertStatement sqlStatement) {
super(sqlStatement);
- tablesContext = new TablesContext(sqlStatement.getTable());
columnNames = sqlStatement.useDefaultColumns() ? schemaMetaData.getAllColumnNames(sqlStatement.getTable().getTableName().getIdentifier().getValue()) : sqlStatement.getColumnNames();
AtomicInteger parametersOffset = new AtomicInteger(0);
insertValueContexts = getInsertValueContexts(parameters, parametersOffset);
+ insertSelectContext = getInsertSelectContext(schemaMetaData, parameters, parametersOffset).orElse(null);
+ tablesContext = getTablesContext(sqlStatement);
onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(parameters, parametersOffset).orElse(null);
generatedKeyContext = new GeneratedKeyContextEngine(schemaMetaData).createGenerateKeyContext(parameters, sqlStatement).orElse(null);
}
+    // Collects the inserted table plus, for INSERT ... SELECT, every table referenced by the SELECT part.
+    private TablesContext getTablesContext(final InsertStatement sqlStatement) {
+        List<SimpleTableSegment> tables = new LinkedList<>();
+        tables.add(sqlStatement.getTable());
+        boolean hasInsertSelect = sqlStatement.getInsertSelect().isPresent();
+        if (hasInsertSelect) {
+            // insertSelectContext is assigned in the constructor before this method is invoked.
+            SelectStatementContext selectStatementContext = insertSelectContext.getSelectStatementContext();
+            tables.addAll(selectStatementContext.getSimpleTableSegments());
+        }
+        return new TablesContext(tables);
+    }
+
private List<InsertValueContext> getInsertValueContexts(final List<Object> parameters, final AtomicInteger parametersOffset) {
List<InsertValueContext> result = new LinkedList<>();
for (Collection<ExpressionSegment> each : getSqlStatement().getAllValueExpressions()) {
@@ -78,6 +91,17 @@ public final class InsertStatementContext extends CommonSQLStatementContext<Inse
return result;
}
+    // Binds the SELECT part of an INSERT ... SELECT and advances parametersOffset past the SELECT's parameters.
+    private Optional<InsertSelectContext> getInsertSelectContext(final SchemaMetaData schemaMetaData, final List<Object> parameters, final AtomicInteger parametersOffset) {
+        Optional<SubquerySegment> insertSelectSegment = getSqlStatement().getInsertSelect();
+        if (!insertSelectSegment.isPresent()) {
+            return Optional.empty();
+        }
+        SelectStatementContext selectStatementContext = new SelectStatementContext(schemaMetaData, parameters, insertSelectSegment.get().getSelect());
+        // Named "result" rather than reusing the name of the insertSelectContext field, which it would otherwise shadow.
+        InsertSelectContext result = new InsertSelectContext(selectStatementContext, parameters, parametersOffset.get());
+        parametersOffset.addAndGet(result.getParametersCount());
+        return Optional.of(result);
+    }
+
private Optional<OnDuplicateUpdateContext> getOnDuplicateKeyUpdateValueContext(final List<Object> parameters, final AtomicInteger parametersOffset) {
if (!getSqlStatement().getOnDuplicateKeyColumns().isPresent()) {
return Optional.empty();
@@ -107,6 +131,9 @@ public final class InsertStatementContext extends CommonSQLStatementContext<Inse
for (InsertValueContext each : insertValueContexts) {
result.add(each.getParameters());
}
+ if (null != insertSelectContext) {
+ result.add(insertSelectContext.getParameters());
+ }
return result;
}
@@ -133,6 +160,6 @@ public final class InsertStatementContext extends CommonSQLStatementContext<Inse
@Override
public Collection<SimpleTableSegment> getAllTables() {
- return Collections.singletonList(getSqlStatement().getTable());
+ return tablesContext.getTables();
}
}
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/statement/dml/InsertStatement.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/statement/dml/InsertStatement.java
index 1bd5645..946b5e0 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/statement/dml/InsertStatement.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/statement/dml/InsertStatement.java
@@ -155,6 +155,9 @@ public final class InsertStatement extends DMLStatement {
if (null != setAssignment) {
return setAssignment.getAssignments().size();
}
+ if (null != insertSelect) {
+ return insertSelect.getSelect().getProjections().getProjections().size();
+ }
return 0;
}