You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2020/08/19 06:47:02 UTC
[shardingsphere] branch master updated: fix rewrite for subquery (#6837)
This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 4dcca4f fix rewrite for subquery (#6837)
4dcca4f is described below
commit 4dcca4f9e530e794f1698da9a4584da1264d6c2b
Author: JingShang Lu <lu...@apache.org>
AuthorDate: Wed Aug 19 14:46:49 2020 +0800
fix rewrite for subquery (#6837)
* fix rewrite for subquery
* fix
* check function name
* fix
* rename func name
---
.../dql/groupby/GroupByMemoryMergedResult.java | 2 +-
.../src/test/resources/sharding/select.xml | 5 +
.../statement/dml/SelectStatementContext.java | 247 +----------------
.../src/main/antlr4/imports/mysql/DDLStatement.g4 | 4 +-
.../sql/parser/mysql/visitor/MySQLVisitor.java | 2 +-
.../sql/parser/sql/util/TableExtractUtils.java | 294 +++++++++++++++++++++
.../test/resources/sql/supported/ddl/create.xml | 2 +-
7 files changed, 307 insertions(+), 249 deletions(-)
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-merge/src/main/java/org/apache/shardingsphere/sharding/merge/dql/groupby/GroupByMemoryMergedResult.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-merge/src/main/java/org/apache/shardingsphere/sharding/merge/dql/groupby/GroupByMemoryMergedResult.java
index 6c7b9a8..0f5295a 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-merge/src/main/java/org/apache/shardingsphere/sharding/merge/dql/groupby/GroupByMemoryMergedResult.java
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-merge/src/main/java/org/apache/shardingsphere/sharding/merge/dql/groupby/GroupByMemoryMergedResult.java
@@ -126,7 +126,7 @@ public final class GroupByMemoryMergedResult extends MemoryMergedResult<Sharding
private boolean getValueCaseSensitiveFromTables(final QueryResult queryResult, final SelectStatementContext selectStatementContext,
final SchemaMetaData schemaMetaData, final int columnIndex) throws SQLException {
- for (SimpleTableSegment each : selectStatementContext.getAllTables()) {
+ for (SimpleTableSegment each : selectStatementContext.getSimpleTableSegments()) {
String tableName = each.getTableName().getIdentifier().getValue();
TableMetaData tableMetaData = schemaMetaData.get(tableName);
Map<String, ColumnMetaData> columns = tableMetaData.getColumns();
diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml
index 26328d1..d26049d 100644
--- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml
+++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml
@@ -58,6 +58,11 @@
<input sql="SELECT * FROM (select b.account_id from (select t_account.account_id from t_account) b where b.account_id=?) a WHERE account_id = 100" parameters="100" />
<output sql="SELECT * FROM (select b.account_id from (select t_account_0.account_id from t_account_0) b where b.account_id=?) a WHERE account_id = 100" parameters="100" />
</rewrite-assertion>
+
+ <rewrite-assertion id="select_with_subquery_in_projection_and_where" db-type="MySQL">
+ <input sql="SELECT (select id from t_account limit 1) as myid FROM (select b.account_id from (select t_account.account_id from t_account) b where b.account_id=?) a WHERE account_id >= (select account_id from t_account limit 1)" parameters="100"/>
+ <output sql="SELECT (select id from t_account_0 limit 1) as myid FROM (select b.account_id from (select t_account_0.account_id from t_account_0) b where b.account_id=?) a WHERE account_id >= (select account_id from t_account_0 limit 1)" parameters="100"/>
+ </rewrite-assertion>
<rewrite-assertion id="select_without_sharding_value_for_parameters">
<input sql="SELECT * FROM db.t_account WHERE amount = ?" parameters="1000" />
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/SelectStatementContext.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/SelectStatementContext.java
index 64ed1dc..b401870 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/SelectStatementContext.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/SelectStatementContext.java
@@ -36,34 +36,18 @@ import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
import org.apache.shardingsphere.sql.parser.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.sql.parser.binder.type.TableAvailable;
import org.apache.shardingsphere.sql.parser.binder.type.WhereAvailable;
-import org.apache.shardingsphere.sql.parser.sql.predicate.PredicateExtractor;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.JoinSpecificationSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.JoinedTableSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.TableFactorSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.TableReferenceSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ColumnProjectionSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ProjectionSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ProjectionsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.order.item.ColumnOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.order.item.ExpressionOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.order.item.IndexOrderByItemSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.order.item.OrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.order.item.TextOrderByItemSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.AndPredicate;
-import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.PredicateSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.WhereSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.generic.OwnerAvailable;
-import org.apache.shardingsphere.sql.parser.sql.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SubqueryTableSegment;
-import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.util.SQLUtil;
+import org.apache.shardingsphere.sql.parser.sql.util.TableExtractUtils;
import org.apache.shardingsphere.sql.parser.sql.util.WhereSegmentExtractUtils;
import java.util.Collection;
-import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -194,68 +178,7 @@ public final class SelectStatementContext extends CommonSQLStatementContext<Sele
@Override
public Collection<SimpleTableSegment> getAllTables() {
- return getTableFromSelect(getSqlStatement());
- }
-
- private Collection<SimpleTableSegment> getAllTablesFromWhere(final WhereSegment where, final Collection<TableSegment> tableSegments) {
- Collection<SimpleTableSegment> result = new LinkedList<>();
- for (AndPredicate each : where.getAndPredicates()) {
- for (PredicateSegment predicate : each.getPredicates()) {
- result.addAll(new PredicateExtractor(tableSegments, predicate).extractTables());
- }
- }
- return result;
- }
-
- private Collection<SimpleTableSegment> getAllTablesFromProjections(final ProjectionsSegment projections, final Collection<TableSegment> tableSegments) {
- Collection<SimpleTableSegment> result = new LinkedList<>();
- for (ProjectionSegment each : projections.getProjections()) {
- Optional<SimpleTableSegment> table = getTableSegment(each, tableSegments);
- table.ifPresent(result::add);
- }
- return result;
- }
-
- private Optional<SimpleTableSegment> getTableSegment(final ProjectionSegment each, final Collection<TableSegment> tableSegments) {
- Optional<OwnerSegment> owner = getTableOwner(each);
- if (owner.isPresent() && isTable(owner.get(), tableSegments)) {
- return Optional .of(new SimpleTableSegment(owner.get().getStartIndex(), owner.get().getStopIndex(), owner.get().getIdentifier()));
- }
- return Optional.empty();
- }
-
- private Optional<OwnerSegment> getTableOwner(final ProjectionSegment each) {
- if (each instanceof OwnerAvailable) {
- return ((OwnerAvailable) each).getOwner();
- }
- if (each instanceof ColumnProjectionSegment) {
- return ((ColumnProjectionSegment) each).getColumn().getOwner();
- }
- return Optional.empty();
- }
-
- private Collection<SimpleTableSegment> getAllTablesFromOrderByItems(final Collection<OrderByItemSegment> orderByItems, final Collection<TableSegment> tableSegments) {
- Collection<SimpleTableSegment> result = new LinkedList<>();
- for (OrderByItemSegment each : orderByItems) {
- if (each instanceof ColumnOrderByItemSegment) {
- Optional<OwnerSegment> owner = ((ColumnOrderByItemSegment) each).getColumn().getOwner();
- if (owner.isPresent() && isTable(owner.get(), tableSegments)) {
- Preconditions.checkState(((ColumnOrderByItemSegment) each).getColumn().getOwner().isPresent());
- OwnerSegment segment = ((ColumnOrderByItemSegment) each).getColumn().getOwner().get();
- result.add(new SimpleTableSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getIdentifier()));
- }
- }
- }
- return result;
- }
-
- private boolean isTable(final OwnerSegment owner, final Collection<TableSegment> tables) {
- for (TableSegment each : tables) {
- if (owner.getIdentifier().getValue().equals(each.getAlias().orElse(null))) {
- return false;
- }
- }
- return true;
+ return TableExtractUtils.getTablesFromSelect(getSqlStatement());
}
@Override
@@ -268,170 +191,6 @@ public final class SelectStatementContext extends CommonSQLStatementContext<Sele
* @return tables.
*/
public Collection<SimpleTableSegment> getSimpleTableSegments() {
- Collection<TableSegment> tables = getTables();
- Collection<SimpleTableSegment> result = new LinkedList<>();
- for (TableSegment each : tables) {
- if (each instanceof SimpleTableSegment) {
- result.add((SimpleTableSegment) each);
- } else {
- result.addAll(getRealTableFromSelect(((SubqueryTableSegment) each).getSubquery().getSelect()));
- }
- }
- return result;
- }
-
- private Collection<SimpleTableSegment> getTableFromSelect(final SelectStatement selectStatement) {
- Collection<SimpleTableSegment> result = new LinkedList<>();
- Collection<TableSegment> realTables = new LinkedList<>();
- Collection<TableSegment> allTables = new LinkedList<>();
- for (TableReferenceSegment each : selectStatement.getTableReferences()) {
- allTables.addAll(getTablesFromTableReference(each));
- realTables.addAll(getRealTablesFromTableReference(each));
- }
- if (selectStatement.getWhere().isPresent()) {
- allTables.addAll(getAllTablesFromWhere(selectStatement.getWhere().get(), realTables));
- }
- result.addAll(getAllTablesFromProjections(selectStatement.getProjections(), realTables));
- if (getSqlStatement().getGroupBy().isPresent()) {
- result.addAll(getAllTablesFromOrderByItems(getSqlStatement().getGroupBy().get().getGroupByItems(), realTables));
- }
- if (getSqlStatement().getOrderBy().isPresent()) {
- result.addAll(getAllTablesFromOrderByItems(getSqlStatement().getOrderBy().get().getOrderByItems(), realTables));
- }
- for (TableSegment each : allTables) {
- if (each instanceof SubqueryTableSegment) {
- result.addAll(getTableFromSelect(((SubqueryTableSegment) each).getSubquery().getSelect()));
- } else {
- result.add((SimpleTableSegment) each);
- }
- }
- return result;
- }
-
- private Collection<SimpleTableSegment> getRealTableFromSelect(final SelectStatement selectStatement) {
- Collection<SimpleTableSegment> result = new LinkedList<>();
- Collection<TableSegment> realTables = new LinkedList<>();
- for (TableReferenceSegment each : selectStatement.getTableReferences()) {
- realTables.addAll(getRealTablesFromTableReference(each));
- }
- for (TableSegment each : realTables) {
- if (each instanceof SubqueryTableSegment) {
- result.addAll(getRealTableFromSelect(((SubqueryTableSegment) each).getSubquery().getSelect()));
- } else {
- result.add((SimpleTableSegment) each);
- }
- }
- return result;
- }
-
- private Collection<TableSegment> getTables() {
- SelectStatement selectStatement = getSqlStatement();
- Collection<TableSegment> result = new LinkedList<>();
- for (TableReferenceSegment each : selectStatement.getTableReferences()) {
- result.addAll(getRealTablesFromTableReference(each));
- }
- return result;
- }
-
- private Collection<TableSegment> getTablesFromTableFactor(final TableFactorSegment tableFactorSegment) {
- Collection<TableSegment> result = new LinkedList<>();
- if (null != tableFactorSegment.getTable() && tableFactorSegment.getTable() instanceof SimpleTableSegment) {
- result.add(tableFactorSegment.getTable());
- }
- if (null != tableFactorSegment.getTable() && tableFactorSegment.getTable() instanceof SubqueryTableSegment) {
- result.add(tableFactorSegment.getTable());
- }
- if (null != tableFactorSegment.getTableReferences() && !tableFactorSegment.getTableReferences().isEmpty()) {
- for (TableReferenceSegment each: tableFactorSegment.getTableReferences()) {
- result.addAll(getTablesFromTableReference(each));
- }
- }
- return result;
- }
-
- private Collection<TableSegment> getRealTablesFromTableFactor(final TableFactorSegment tableFactorSegment) {
- Collection<TableSegment> result = new LinkedList<>();
- if (null != tableFactorSegment.getTable() && tableFactorSegment.getTable() instanceof SimpleTableSegment) {
- result.add(tableFactorSegment.getTable());
- }
- if (null != tableFactorSegment.getTable() && tableFactorSegment.getTable() instanceof SubqueryTableSegment) {
- result.add(tableFactorSegment.getTable());
- }
- if (null != tableFactorSegment.getTableReferences() && !tableFactorSegment.getTableReferences().isEmpty()) {
- for (TableReferenceSegment each: tableFactorSegment.getTableReferences()) {
- result.addAll(getRealTablesFromTableReference(each));
- }
- }
- return result;
- }
-
- private Collection<TableSegment> getTablesFromTableReference(final TableReferenceSegment tableReferenceSegment) {
- Collection<TableSegment> result = new LinkedList<>();
- if (null != tableReferenceSegment.getTableFactor()) {
- result.addAll(getTablesFromTableFactor(tableReferenceSegment.getTableFactor()));
- }
- if (null != tableReferenceSegment.getJoinedTables()) {
- for (JoinedTableSegment each : tableReferenceSegment.getJoinedTables()) {
- result.addAll(getTablesFromJoinTable(each, result));
- }
- }
- return result;
- }
-
- private Collection<TableSegment> getRealTablesFromTableReference(final TableReferenceSegment tableReferenceSegment) {
- Collection<TableSegment> result = new LinkedList<>();
- if (null != tableReferenceSegment.getTableFactor()) {
- result.addAll(getRealTablesFromTableFactor(tableReferenceSegment.getTableFactor()));
- }
- if (null != tableReferenceSegment.getJoinedTables()) {
- for (JoinedTableSegment each : tableReferenceSegment.getJoinedTables()) {
- result.addAll(getRealTablesFromJoinTable(each));
- }
- }
- return result;
- }
-
- private Collection<TableSegment> getTablesFromJoinTable(final JoinedTableSegment joinedTableSegment, final Collection<TableSegment> tableSegments) {
- Collection<TableSegment> result = new LinkedList<>();
- Collection<TableSegment> realTables = new LinkedList<>();
- realTables.addAll(tableSegments);
- if (null != joinedTableSegment.getTableFactor()) {
- result.addAll(getTablesFromTableFactor(joinedTableSegment.getTableFactor()));
- realTables.addAll(getTablesFromTableFactor(joinedTableSegment.getTableFactor()));
- }
- if (null != joinedTableSegment.getJoinSpecification()) {
- result.addAll(getTablesFromJoinSpecification(joinedTableSegment.getJoinSpecification(), realTables));
- }
- return result;
- }
-
- private Collection<TableSegment> getRealTablesFromJoinTable(final JoinedTableSegment joinedTableSegment) {
- Collection<TableSegment> result = new LinkedList<>();
- if (null != joinedTableSegment.getTableFactor()) {
- result.addAll(getTablesFromTableFactor(joinedTableSegment.getTableFactor()));
- }
- return result;
- }
-
- private Collection<SimpleTableSegment> getTablesFromJoinSpecification(final JoinSpecificationSegment joinSpecificationSegment, final Collection<TableSegment> tableSegments) {
- Collection<SimpleTableSegment> result = new LinkedList<>();
- Collection<AndPredicate> andPredicates = joinSpecificationSegment.getAndPredicates();
- for (AndPredicate each : andPredicates) {
- for (PredicateSegment e : each.getPredicates()) {
- if (null != e.getColumn() && (e.getColumn().getOwner().isPresent())) {
- OwnerSegment ownerSegment = e.getColumn().getOwner().get();
- if (isTable(ownerSegment, tableSegments)) {
- result.add(new SimpleTableSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(), ownerSegment.getIdentifier()));
- }
- }
- if (null != e.getRightValue() && (e.getRightValue() instanceof ColumnSegment) && ((ColumnSegment) e.getRightValue()).getOwner().isPresent()) {
- OwnerSegment ownerSegment = ((ColumnSegment) e.getRightValue()).getOwner().get();
- if (isTable(ownerSegment, tableSegments)) {
- result.add(new SimpleTableSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(), ownerSegment.getIdentifier()));
- }
- }
- }
- }
- return result;
+ return TableExtractUtils.getSimpleTableFromSelect(getSqlStatement());
}
}
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/antlr4/imports/mysql/DDLStatement.g4 b/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/antlr4/imports/mysql/DDLStatement.g4
index b89ae1f..ef15196 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/antlr4/imports/mysql/DDLStatement.g4
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/antlr4/imports/mysql/DDLStatement.g4
@@ -28,8 +28,8 @@ partitionClause
;
partitionTypeDef
- : LINEAR KEY partitionKeyAlgorithm? columnNames
- | LINEAR HASH LP_ bitExpr RP_
+ : LINEAR? KEY partitionKeyAlgorithm? columnNames
+ | LINEAR? HASH LP_ bitExpr RP_
| (RANGE | LIST) (LP_ bitExpr RP_ | COLUMNS columnNames )
;
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/MySQLVisitor.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/MySQLVisitor.java
index f4192cf..7a3be40 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/MySQLVisitor.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/MySQLVisitor.java
@@ -289,7 +289,7 @@ public abstract class MySQLVisitor extends MySQLStatementBaseVisitor<ASTNode> {
private ASTNode createPredicateRightValue(final BooleanPrimaryContext ctx) {
if (null != ctx.subquery()) {
- new SubquerySegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (SelectStatement) visit(ctx.subquery()));
+ return new SubquerySegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (SelectStatement) visit(ctx.subquery()));
}
ASTNode rightValue = visit(ctx.predicate());
return createPredicateRightValue(ctx, rightValue);
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/TableExtractUtils.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/TableExtractUtils.java
new file mode 100644
index 0000000..38467c4
--- /dev/null
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/util/TableExtractUtils.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.sql.parser.sql.util;
+
+import com.google.common.base.Preconditions;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.JoinSpecificationSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.JoinedTableSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.TableFactorSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.TableReferenceSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.subquery.SubqueryExpressionSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ColumnProjectionSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ProjectionSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ProjectionsSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.SubqueryProjectionSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.order.item.ColumnOrderByItemSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.order.item.OrderByItemSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.AndPredicate;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.PredicateSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.WhereSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateCompareRightValue;
+import org.apache.shardingsphere.sql.parser.sql.segment.generic.OwnerAvailable;
+import org.apache.shardingsphere.sql.parser.sql.segment.generic.OwnerSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SubqueryTableSegment;
+import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.TableSegment;
+import org.apache.shardingsphere.sql.parser.sql.statement.dml.SelectStatement;
+
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.Optional;
+
+public final class TableExtractUtils {
+    /**
+     * Get the table segments that should be rewritten for a select statement.
+     *
+     * <p>Segments are returned in this order: those found in the projections, then the GROUP BY items, then the
+     * ORDER BY items, and finally those from the FROM clause (including owners used in join conditions) and the
+     * WHERE clause. Subqueries met along the way are visited recursively.</p>
+     *
+     * @param selectStatement select statement to inspect
+     * @return simple table segments to rewrite
+     */
+    public static Collection<SimpleTableSegment> getTablesFromSelect(final SelectStatement selectStatement) {
+        // Tables taken from the FROM clause; used as the scope that tells table names apart from aliases.
+        Collection<TableSegment> scope = new LinkedList<>();
+        // FROM / WHERE findings, some of which may still be subqueries that need expanding.
+        Collection<TableSegment> candidates = new LinkedList<>();
+        for (TableReferenceSegment reference : selectStatement.getTableReferences()) {
+            candidates.addAll(getTablesFromTableReference(reference));
+            scope.addAll(getRealTablesFromTableReference(reference));
+        }
+        selectStatement.getWhere().ifPresent(where -> candidates.addAll(getAllTablesFromWhere(where, scope)));
+        Collection<SimpleTableSegment> result = new LinkedList<>(getAllTablesFromProjections(selectStatement.getProjections(), scope));
+        selectStatement.getGroupBy().ifPresent(groupBy -> result.addAll(getAllTablesFromOrderByItems(groupBy.getGroupByItems(), scope)));
+        selectStatement.getOrderBy().ifPresent(orderBy -> result.addAll(getAllTablesFromOrderByItems(orderBy.getOrderByItems(), scope)));
+        for (TableSegment candidate : candidates) {
+            if (candidate instanceof SubqueryTableSegment) {
+                result.addAll(getTablesFromSelect(((SubqueryTableSegment) candidate).getSubquery().getSelect()));
+            } else {
+                result.add((SimpleTableSegment) candidate);
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Get the simple table segments declared in the FROM clause of a select statement.
+     *
+     * <p>Subqueries used as derived tables are expanded recursively, so the result holds only simple table segments.</p>
+     *
+     * @param selectStatement select statement to inspect
+     * @return simple table segments of the FROM clause
+     */
+    public static Collection<SimpleTableSegment> getSimpleTableFromSelect(final SelectStatement selectStatement) {
+        Collection<TableSegment> fromTables = new LinkedList<>();
+        for (TableReferenceSegment reference : selectStatement.getTableReferences()) {
+            fromTables.addAll(getRealTablesFromTableReference(reference));
+        }
+        Collection<SimpleTableSegment> result = new LinkedList<>();
+        for (TableSegment table : fromTables) {
+            if (!(table instanceof SubqueryTableSegment)) {
+                result.add((SimpleTableSegment) table);
+                continue;
+            }
+            result.addAll(getSimpleTableFromSelect(((SubqueryTableSegment) table).getSubquery().getSelect()));
+        }
+        return result;
+    }
+
+ /**
+ * Collect the table segments reachable from a table reference, including owner segments taken from join conditions.
+ *
+ * <p>Joined tables are processed in order, and each one is resolved against the segments collected so far,
+ * so the order of the joined tables affects which owners are treated as aliases.</p>
+ *
+ * @param tableReferenceSegment table reference to inspect
+ * @return simple and subquery table segments found in the reference
+ */
+ private static Collection<TableSegment> getTablesFromTableReference(final TableReferenceSegment tableReferenceSegment) {
+ Collection<TableSegment> result = new LinkedList<>();
+ if (null != tableReferenceSegment.getTableFactor()) {
+ result.addAll(getTablesFromTableFactor(tableReferenceSegment.getTableFactor()));
+ }
+ if (null != tableReferenceSegment.getJoinedTables()) {
+ for (JoinedTableSegment each : tableReferenceSegment.getJoinedTables()) {
+ // `result` (everything gathered so far) doubles as the alias scope for this join's ON clause.
+ result.addAll(getTablesFromJoinTable(each, result));
+ }
+ }
+ return result;
+ }
+
+    /**
+     * Collect the table segments declared by a table reference: those of its table factor, followed by those
+     * returned by getRealTablesFromJoinTable for each joined table.
+     *
+     * @param tableReferenceSegment table reference to inspect
+     * @return table segments declared by the reference
+     */
+    private static Collection<TableSegment> getRealTablesFromTableReference(final TableReferenceSegment tableReferenceSegment) {
+        Collection<TableSegment> result = new LinkedList<>();
+        TableFactorSegment factor = tableReferenceSegment.getTableFactor();
+        if (null != factor) {
+            result.addAll(getRealTablesFromTableFactor(factor));
+        }
+        if (null == tableReferenceSegment.getJoinedTables()) {
+            return result;
+        }
+        for (JoinedTableSegment joinedTable : tableReferenceSegment.getJoinedTables()) {
+            result.addAll(getRealTablesFromJoinTable(joinedTable));
+        }
+        return result;
+    }
+
+    /**
+     * Collect table segments from a table factor: the factor's own simple or subquery table, followed by everything
+     * found in its nested table references (via getTablesFromTableReference, so join-condition owners are included).
+     *
+     * @param tableFactorSegment table factor to inspect
+     * @return table segments found in the factor
+     */
+    private static Collection<TableSegment> getTablesFromTableFactor(final TableFactorSegment tableFactorSegment) {
+        Collection<TableSegment> result = new LinkedList<>();
+        TableSegment table = tableFactorSegment.getTable();
+        // instanceof is false for null, which covers the "no table" case.
+        if (table instanceof SimpleTableSegment || table instanceof SubqueryTableSegment) {
+            result.add(table);
+        }
+        if (null != tableFactorSegment.getTableReferences()) {
+            for (TableReferenceSegment nested : tableFactorSegment.getTableReferences()) {
+                result.addAll(getTablesFromTableReference(nested));
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Collect the table segments declared by a table factor: its own simple or subquery table, followed by the
+     * tables declared by its nested table references.
+     *
+     * @param tableFactorSegment table factor to inspect
+     * @return declared table segments
+     */
+    private static Collection<TableSegment> getRealTablesFromTableFactor(final TableFactorSegment tableFactorSegment) {
+        Collection<TableSegment> result = new LinkedList<>();
+        TableSegment declared = tableFactorSegment.getTable();
+        boolean supported = declared instanceof SimpleTableSegment || declared instanceof SubqueryTableSegment;
+        if (supported) {
+            result.add(declared);
+        }
+        if (null == tableFactorSegment.getTableReferences()) {
+            return result;
+        }
+        for (TableReferenceSegment nested : tableFactorSegment.getTableReferences()) {
+            result.addAll(getRealTablesFromTableReference(nested));
+        }
+        return result;
+    }
+
+    /**
+     * Collect table segments from a joined table: those of its table factor plus the owners used in its join
+     * specification that are not aliases.
+     *
+     * @param joinedTableSegment joined table to inspect
+     * @param tableSegments tables already in scope; together with this join's own tables they form the alias scope
+     * @return table segments found in the joined table
+     */
+    private static Collection<TableSegment> getTablesFromJoinTable(final JoinedTableSegment joinedTableSegment, final Collection<TableSegment> tableSegments) {
+        Collection<TableSegment> result = new LinkedList<>();
+        Collection<TableSegment> scope = new LinkedList<>(tableSegments);
+        if (null != joinedTableSegment.getTableFactor()) {
+            Collection<TableSegment> factorTables = getTablesFromTableFactor(joinedTableSegment.getTableFactor());
+            result.addAll(factorTables);
+            scope.addAll(factorTables);
+        }
+        if (null != joinedTableSegment.getJoinSpecification()) {
+            result.addAll(getTablesFromJoinSpecification(joinedTableSegment.getJoinSpecification(), scope));
+        }
+        return result;
+    }
+
+    /**
+     * Collect the table segments declared by a joined table's table factor.
+     *
+     * <p>This deliberately uses getRealTablesFromTableFactor, the same method getRealTablesFromTableReference applies
+     * to the leading factor. getTablesFromTableFactor would also return owner segments synthesized from nested join
+     * conditions. Those are not declared tables, and they would leak into the "real table" results.</p>
+     *
+     * @param joinedTableSegment joined table to inspect
+     * @return declared table segments, empty when the joined table has no table factor
+     */
+    private static Collection<TableSegment> getRealTablesFromJoinTable(final JoinedTableSegment joinedTableSegment) {
+        Collection<TableSegment> result = new LinkedList<>();
+        if (null != joinedTableSegment.getTableFactor()) {
+            result.addAll(getRealTablesFromTableFactor(joinedTableSegment.getTableFactor()));
+        }
+        return result;
+    }
+
+    /**
+     * Turn column owners used in a join specification into table segments.
+     *
+     * <p>For every predicate, the owner of the left column and the owner of a right-hand column value are considered.
+     * Each owner is kept only if isTable reports it is not an alias of a table in {@code tableSegments}.</p>
+     *
+     * @param joinSpecificationSegment join specification (ON clause) to inspect
+     * @param tableSegments tables in scope, used to recognize aliases
+     * @return table segments built from qualifying owners, left owner before right owner per predicate
+     */
+    private static Collection<SimpleTableSegment> getTablesFromJoinSpecification(final JoinSpecificationSegment joinSpecificationSegment, final Collection<TableSegment> tableSegments) {
+        Collection<SimpleTableSegment> result = new LinkedList<>();
+        for (AndPredicate andPredicate : joinSpecificationSegment.getAndPredicates()) {
+            for (PredicateSegment predicate : andPredicate.getPredicates()) {
+                Optional<OwnerSegment> leftOwner = null == predicate.getColumn() ? Optional.empty() : predicate.getColumn().getOwner();
+                Optional<OwnerSegment> rightOwner = predicate.getRightValue() instanceof ColumnSegment ? ((ColumnSegment) predicate.getRightValue()).getOwner() : Optional.empty();
+                leftOwner.filter(owner -> isTable(owner, tableSegments))
+                        .ifPresent(owner -> result.add(new SimpleTableSegment(owner.getStartIndex(), owner.getStopIndex(), owner.getIdentifier())));
+                rightOwner.filter(owner -> isTable(owner, tableSegments))
+                        .ifPresent(owner -> result.add(new SimpleTableSegment(owner.getStartIndex(), owner.getStopIndex(), owner.getIdentifier())));
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Collect the table segments referenced by every predicate of a WHERE clause.
+     *
+     * @param where WHERE clause to inspect
+     * @param tableSegments tables in scope, used to recognize aliases
+     * @return table segments referenced by the clause, in predicate order
+     */
+    private static Collection<SimpleTableSegment> getAllTablesFromWhere(final WhereSegment where, final Collection<TableSegment> tableSegments) {
+        Collection<SimpleTableSegment> result = new LinkedList<>();
+        where.getAndPredicates().forEach(andPredicate ->
+                andPredicate.getPredicates().forEach(predicate -> result.addAll(getAllTablesFromPredicate(predicate, tableSegments))));
+        return result;
+    }
+
+    /**
+     * Collect the table segments referenced by a single predicate.
+     *
+     * <p>The left column's owner and a right-hand column's owner are reported only when isTable says they are not
+     * aliases of a table in {@code tableSegments}. When the right value is a comparison whose expression is a
+     * subquery, that subquery's tables are collected recursively.</p>
+     *
+     * @param predicate predicate to inspect
+     * @param tableSegments tables in scope, used to recognize aliases
+     * @return table segments referenced by the predicate
+     */
+    private static Collection<SimpleTableSegment> getAllTablesFromPredicate(final PredicateSegment predicate, final Collection<TableSegment> tableSegments) {
+        Collection<SimpleTableSegment> result = new LinkedList<>();
+        // Guard against predicates without a column, as getTablesFromJoinSpecification does.
+        if (null != predicate.getColumn() && predicate.getColumn().getOwner().isPresent() && isTable(predicate.getColumn().getOwner().get(), tableSegments)) {
+            OwnerSegment segment = predicate.getColumn().getOwner().get();
+            result.add(new SimpleTableSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getIdentifier()));
+        }
+        if (predicate.getRightValue() instanceof PredicateCompareRightValue) {
+            PredicateCompareRightValue compareRightValue = (PredicateCompareRightValue) predicate.getRightValue();
+            if (compareRightValue.getExpression() instanceof SubqueryExpressionSegment) {
+                result.addAll(getTablesFromSelect(((SubqueryExpressionSegment) compareRightValue.getExpression()).getSubquery().getSelect()));
+            }
+        } else if (predicate.getRightValue() instanceof ColumnSegment) {
+            // An unqualified right-hand column (e.g. `WHERE a = b`) has no owner and names no table, so it must not fail.
+            // Alias owners are skipped exactly like on the left-hand side.
+            Optional<OwnerSegment> rightOwner = ((ColumnSegment) predicate.getRightValue()).getOwner();
+            if (rightOwner.isPresent() && isTable(rightOwner.get(), tableSegments)) {
+                OwnerSegment segment = rightOwner.get();
+                result.add(new SimpleTableSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getIdentifier()));
+            }
+        }
+        // NOTE(review): other right values contribute no tables. This includes the bare SubquerySegment that
+        // MySQLVisitor#createPredicateRightValue can return. Confirm whether those subqueries need rewriting too.
+        return result;
+    }
+
+ /**
+ * Collect tables referenced by a projection list.
+ *
+ * <p>A subquery projection contributes every table of its SELECT; any other projection contributes its owner
+ * via {@link #getTableSegment} when that owner is a table rather than an alias. A {@code null} or empty
+ * projection list yields an empty result.</p>
+ *
+ * @param projections projection list to inspect, may be {@code null}
+ * @param tableSegments tables of the enclosing statement, used to tell aliases from table names
+ * @return tables referenced by the projections, possibly empty
+ */
+ private static Collection<SimpleTableSegment> getAllTablesFromProjections(final ProjectionsSegment projections, final Collection<TableSegment> tableSegments) {
+ Collection<SimpleTableSegment> result = new LinkedList<>();
+ boolean hasProjections = null != projections && !projections.getProjections().isEmpty();
+ if (hasProjections) {
+ for (ProjectionSegment projection : projections.getProjections()) {
+ if (projection instanceof SubqueryProjectionSegment) {
+ SubqueryProjectionSegment subqueryProjection = (SubqueryProjectionSegment) projection;
+ result.addAll(getTablesFromSelect(subqueryProjection.getSubquery().getSelect()));
+ continue;
+ }
+ getTableSegment(projection, tableSegments).ifPresent(result::add);
+ }
+ }
+ return result;
+ }
+
+ /**
+ * Turn a projection's owner into a table segment when that owner names a table.
+ *
+ * @param each projection to inspect
+ * @param tableSegments tables of the enclosing statement, used to tell aliases from table names
+ * @return a table segment spanning the owner, or empty when the projection has no owner or the owner is an alias
+ */
+ private static Optional<SimpleTableSegment> getTableSegment(final ProjectionSegment each, final Collection<TableSegment> tableSegments) {
+ return getTableOwner(each)
+ .filter(owner -> isTable(owner, tableSegments))
+ .map(owner -> new SimpleTableSegment(owner.getStartIndex(), owner.getStopIndex(), owner.getIdentifier()));
+ }
+
+ /**
+ * Owner qualifier of a projection, if its type exposes one.
+ *
+ * <p>{@code OwnerAvailable} projections are checked first; otherwise a {@code ColumnProjectionSegment} yields
+ * the owner of its column. Any other projection type has no owner.</p>
+ *
+ * @param each projection to inspect
+ * @return the projection's owner, or empty
+ */
+ private static Optional<OwnerSegment> getTableOwner(final ProjectionSegment each) {
+ Optional<OwnerSegment> result = Optional.empty();
+ if (each instanceof OwnerAvailable) {
+ result = ((OwnerAvailable) each).getOwner();
+ } else if (each instanceof ColumnProjectionSegment) {
+ result = ((ColumnProjectionSegment) each).getColumn().getOwner();
+ }
+ return result;
+ }
+
+ /**
+ * Collect tables referenced by the owners of column ORDER BY items.
+ *
+ * <p>Only {@code ColumnOrderByItemSegment} items are inspected; their column owner is reported when
+ * {@link #isTable} says it is not an alias of one of {@code tableSegments}. Other item kinds are ignored.</p>
+ *
+ * @param orderByItems ORDER BY items to inspect
+ * @param tableSegments tables of the enclosing statement, used to tell aliases from table names
+ * @return tables referenced by the ORDER BY items, possibly empty
+ */
+ private static Collection<SimpleTableSegment> getAllTablesFromOrderByItems(final Collection<OrderByItemSegment> orderByItems, final Collection<TableSegment> tableSegments) {
+ Collection<SimpleTableSegment> result = new LinkedList<>();
+ for (OrderByItemSegment each : orderByItems) {
+ if (!(each instanceof ColumnOrderByItemSegment)) {
+ continue;
+ }
+ // Presence is tested here, so the owner can be used directly; no extra precondition or second lookup needed.
+ Optional<OwnerSegment> owner = ((ColumnOrderByItemSegment) each).getColumn().getOwner();
+ if (owner.isPresent() && isTable(owner.get(), tableSegments)) {
+ OwnerSegment segment = owner.get();
+ result.add(new SimpleTableSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getIdentifier()));
+ }
+ }
+ return result;
+ }
+
+ /**
+ * Decide whether an owner qualifier names a table rather than an alias.
+ *
+ * <p>The owner counts as an alias, and this returns {@code false}, as soon as its identifier value equals the
+ * alias of any of {@code tables}. Otherwise it returns {@code true}, including when {@code tables} is empty.</p>
+ *
+ * @param owner owner qualifier to classify
+ * @param tables tables whose aliases are checked
+ * @return {@code true} when no table alias matches the owner
+ */
+ private static boolean isTable(final OwnerSegment owner, final Collection<TableSegment> tables) {
+ return tables.stream().noneMatch(each -> owner.getIdentifier().getValue().equals(each.getAlias().orElse(null)));
+ }
+}
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-test/src/test/resources/sql/supported/ddl/create.xml b/shardingsphere-sql-parser/shardingsphere-sql-parser-test/src/test/resources/sql/supported/ddl/create.xml
index cfc71f0..c8ad26e 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-test/src/test/resources/sql/supported/ddl/create.xml
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-test/src/test/resources/sql/supported/ddl/create.xml
@@ -73,7 +73,7 @@
t_order (order_id) " db-types="MySQL,SQLServer" />
<sql-case id="create_index_with_back_quota" value="CREATE INDEX `order_index` ON `t_order` (`order_id`)" db-types="MySQL" />
<sql-case id="create_composite_index" value="CREATE INDEX order_index ON t_order (order_id, user_id, status)" />
- <sql-case id="create_btree_index" value="CREATE INDEX order_index ON t_order USING BTREE (order_id)" db-types="MySQL" />
+ <sql-case id="create_btree_index" value="CREATE INDEX order_index ON t_order USING BTREE (order_id)" db-types="PostgreSQL" />
 <sql-case id="create_table_with_quota" value="CREATE TABLE &quot;t_order&quot; (&quot;order_id&quot; NUMBER(10), &quot;user_id&quot; NUMBER(10), &quot;status&quot; VARCHAR2(10), &quot;column1&quot; VARCHAR2(10), &quot;column2&quot; VARCHAR2(10), &quot;column3&quot; VARCHAR2(10))" db-types="Oracle" />
<sql-case id="create_table_with_column_on_null_default" value="CREATE TABLE t_order (order_id NUMBER(10) DEFAULT ON NULL 0, user_id NUMBER(10), status VARCHAR2(10), column1 VARCHAR2(10), column2 VARCHAR2(10), column3 VARCHAR2(10))" db-types="Oracle" />
<sql-case id="create_table_with_column_identity" value="CREATE TABLE t_order (order_id NUMBER(10) GENERATED BY DEFAULT AS IDENTITY START WITH 1 MAXVALUE 100, user_id NUMBER(10), status VARCHAR2(10), column1 VARCHAR2(10), column2 VARCHAR2(10), column3 VARCHAR2(10))" db-types="Oracle" />