You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by du...@apache.org on 2023/02/16 06:19:01 UTC

[shardingsphere] branch master updated: Fix union extract table error. (#24186)

This is an automated email from the ASF dual-hosted git repository.

duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 66e873cce76 Fix union extract table error. (#24186)
66e873cce76 is described below

commit 66e873cce766e2c7b0bd5d4fc1943651b315ec52
Author: Chuxin Chen <ch...@qq.com>
AuthorDate: Thu Feb 16 14:18:52 2023 +0800

    Fix union extract table error. (#24186)
---
 .../sql/common/extractor/TableExtractor.java       | 10 ++++----
 .../sql/common/extractor/TableExtractorTest.java   | 27 ++++++++++++++++++++++
 2 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java b/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
index 62082b3c077..f4706f8ad7a 100644
--- a/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
+++ b/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
@@ -74,6 +74,11 @@ public final class TableExtractor {
      * @param selectStatement select statement
      */
     public void extractTablesFromSelect(final SelectStatement selectStatement) {
+        if (selectStatement.getCombine().isPresent()) {
+            CombineSegment combineSegment = selectStatement.getCombine().get();
+            extractTablesFromSelect(combineSegment.getLeft());
+            extractTablesFromSelect(combineSegment.getRight());
+        }
         if (null != selectStatement.getFrom() && !selectStatement.getCombine().isPresent()) {
             extractTablesFromTableSegment(selectStatement.getFrom());
         }
@@ -91,11 +96,6 @@ public final class TableExtractor {
         }
         Optional<LockSegment> lockSegment = SelectStatementHandler.getLockSegment(selectStatement);
         lockSegment.ifPresent(this::extractTablesFromLock);
-        if (selectStatement.getCombine().isPresent()) {
-            CombineSegment combineSegment = selectStatement.getCombine().get();
-            extractTablesFromSelect(combineSegment.getLeft());
-            extractTablesFromSelect(combineSegment.getRight());
-        }
     }
     
     private void extractTablesFromTableSegment(final TableSegment tableSegment) {
diff --git a/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java b/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
index 8af6ba6045d..0cc1991eee7 100644
--- a/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
+++ b/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
@@ -28,9 +28,11 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDupl
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.AggregationProjectionSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.LockSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
@@ -143,4 +145,29 @@ public final class TableExtractorTest {
         result.setFrom(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue(tableName))));
         return result;
     }
+    
+    @Test
+    public void assertExtractTablesFromCombineSegmentWithColumnProjection() {
+        MySQLSelectStatement selectStatement = createSelectStatementWithColumnProjection("t_order");
+        selectStatement.setCombine(new CombineSegment(0, 0, createSelectStatementWithColumnProjection("t_order"), CombineType.UNION, createSelectStatementWithColumnProjection("t_order_item")));
+        tableExtractor.extractTablesFromSelect(selectStatement);
+        Collection<SimpleTableSegment> actual = tableExtractor.getRewriteTables();
+        assertThat(actual.size(), is(2));
+        Iterator<SimpleTableSegment> iterator = actual.iterator();
+        assertTableSegment(iterator.next(), 0, 0, "t_order");
+        assertTableSegment(iterator.next(), 0, 0, "t_order_item");
+    }
+    
+    private MySQLSelectStatement createSelectStatementWithColumnProjection(final String tableName) {
+        MySQLSelectStatement result = new MySQLSelectStatement();
+        ProjectionsSegment projections = new ProjectionsSegment(0, 0);
+        ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("id"));
+        columnSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a")));
+        projections.getProjections().add(new ColumnProjectionSegment(columnSegment));
+        result.setProjections(projections);
+        SimpleTableSegment tableSegment = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue(tableName)));
+        tableSegment.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
+        result.setFrom(tableSegment);
+        return result;
+    }
 }