You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by du...@apache.org on 2022/05/20 01:42:22 UTC

[shardingsphere] branch master updated: Fix the sharding table skip OnDuplicatedKeyUpdate (#17756)

This is an automated email from the ASF dual-hosted git repository.

duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 50ad6cf68ce Fix the sharding table skip OnDuplicatedKeyUpdate (#17756)
50ad6cf68ce is described below

commit 50ad6cf68ce1b11c01644c6d0f86d036bf9a60ae
Author: cheese8 <yi...@163.com>
AuthorDate: Fri May 20 09:42:15 2022 +0800

    Fix the sharding table skip OnDuplicatedKeyUpdate (#17756)
    
    * Fix the sharding exception on OnDuplicatedKey
    
    * Update TableExtractor.java
    
    * align on review
    
    * Update TableExtractorTest.java
    
    * Update TableExtractorTest.java
    
    * Update insert_column.xml
    
    * Update TableExtractor.java
---
 .../sql/common/extractor/TableExtractor.java       | 16 +++++++++++++
 .../sql/common/extractor/TableExtractorTest.java   | 26 ++++++++++++++++++++++
 .../query-with-cipher/dml/insert/insert_column.xml |  5 +++++
 3 files changed, 47 insertions(+)

diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
index 1d374a6c43d..65317a43d34 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
@@ -20,6 +20,7 @@ package org.apache.shardingsphere.sql.parser.sql.common.extractor;
 import lombok.Getter;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.ddl.routine.RoutineBodySegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.ddl.routine.ValidStatementSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
@@ -50,6 +51,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertState
 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
 import org.apache.shardingsphere.sql.parser.sql.dialect.handler.ddl.CreateTableStatementHandler;
+import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
 import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
 
 import java.util.Collection;
@@ -217,11 +219,25 @@ public final class TableExtractor {
                 extractTablesFromExpression(each);
             }
         }
+        InsertStatementHandler.getOnDuplicateKeyColumnsSegment(insertStatement).ifPresent(each -> extractTablesFromAssignmentItems(each.getColumns()));
         if (insertStatement.getInsertSelect().isPresent()) {
             extractTablesFromSelect(insertStatement.getInsertSelect().get().getSelect());
         }
     }
     
+    private void extractTablesFromAssignmentItems(final Collection<AssignmentSegment> assignmentItems) {
+        assignmentItems.stream().map(AssignmentSegment::getColumns).forEach(this::extractTablesFromColumnSegments);
+    }
+    
+    private void extractTablesFromColumnSegments(final Collection<ColumnSegment> columnSegments) {
+        // Only owner-qualified columns (e.g. "t_account.status") whose owner passes needRewrite contribute a table.
+        // The recorded table segment reuses the owner's start/stop index, so it spans the qualifier token itself.
+        for (ColumnSegment each : columnSegments) {
+            each.getOwner().filter(this::needRewrite).ifPresent(owner -> rewriteTables.add(
+                    new SimpleTableSegment(new TableNameSegment(owner.getStartIndex(), owner.getStopIndex(), owner.getIdentifier()))));
+        }
+    }
+    
     /**
      * Extract table that should be rewrite from update statement.
      *
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
index 21c7679a938..421863cb674 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
@@ -17,13 +17,23 @@
 
 package org.apache.shardingsphere.sql.parser.sql.common.extractor;
 
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDuplicateKeyColumnsSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.LockSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLInsertStatement;
 import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Iterator;
 import java.util.Optional;
 
@@ -56,6 +66,22 @@ public final class TableExtractorTest {
         assertTableSegment(tableSegmentIterator.next(), 143, 154, "t_order_item");
     }
     
+    @Test
+    public void assertExtractTablesFromInsert() {
+        ColumnSegment assignedColumn = new ColumnSegment(133, 136, new IdentifierValue("id"));
+        assignedColumn.setOwner(new OwnerSegment(130, 132, new IdentifierValue("t_order")));
+        Collection<AssignmentSegment> onDuplicateKeyAssignments = new ArrayList<>();
+        onDuplicateKeyAssignments.add(new ColumnAssignmentSegment(130, 140, Arrays.asList(assignedColumn), new LiteralExpressionSegment(141, 142, 1)));
+        MySQLInsertStatement insertStatement = new MySQLInsertStatement();
+        insertStatement.setTable(new SimpleTableSegment(new TableNameSegment(122, 128, new IdentifierValue("t_order"))));
+        insertStatement.setOnDuplicateKeyColumns(new OnDuplicateKeyColumnsSegment(130, 140, onDuplicateKeyAssignments));
+        tableExtractor.extractTablesFromInsert(insertStatement);
+        assertThat(tableExtractor.getRewriteTables().size(), is(2));
+        Iterator<SimpleTableSegment> actualTables = tableExtractor.getRewriteTables().iterator();
+        assertTableSegment(actualTables.next(), 122, 128, "t_order");
+        assertTableSegment(actualTables.next(), 130, 132, "t_order");
+    }
+    
     private void assertTableSegment(final SimpleTableSegment actual, final int expectedStartIndex, final int expectedStopIndex, final String expectedTableName) {
         assertThat(actual.getStartIndex(), is(expectedStartIndex));
         assertThat(actual.getStopIndex(), is(expectedStopIndex));
diff --git a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/mix/case/query-with-cipher/dml/insert/insert_column.xml b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/mix/case/query-with-cipher/dml/insert/insert_column.xml
index 13e3ecdf960..a37516fbd62 100644
--- a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/mix/case/query-with-cipher/dml/insert/insert_column.xml
+++ b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/mix/case/query-with-cipher/dml/insert/insert_column.xml
@@ -38,6 +38,11 @@
         <input sql="INSERT INTO t_account(password, amount, status) VALUES ('aaa', 1000, 'OK'), ('bbb', 2000, 'OK'), ('ccc', 3000, 'OK'), ('ddd', 4000, 'OK')" />
         <output sql="INSERT INTO t_account_1(cipher_password, assisted_query_password, cipher_amount, status, account_id) VALUES ('encrypt_aaa', 'assisted_query_aaa', 'encrypt_1000', 'OK', 1), ('encrypt_bbb', 'assisted_query_bbb', 'encrypt_2000', 'OK', 1), ('encrypt_ccc', 'assisted_query_ccc', 'encrypt_3000', 'OK', 1), ('encrypt_ddd', 'assisted_query_ddd', 'encrypt_4000', 'OK', 1)" />
     </rewrite-assertion>
+
+    <rewrite-assertion id="insert_values_with_columns_without_id_for_literals_on_duplicated" db-types="MySQL">
+        <input sql="INSERT INTO t_account(password, amount, status) VALUES ('aaa', 1000, 'OK') ON DUPLICATE KEY UPDATE t_account.status='OK'" />
+        <output sql="INSERT INTO t_account_1(cipher_password, assisted_query_password, cipher_amount, status, account_id) VALUES ('encrypt_aaa', 'assisted_query_aaa', 'encrypt_1000', 'OK', 1) ON DUPLICATE KEY UPDATE t_account_1.status='OK'" />
+    </rewrite-assertion>
     
     <rewrite-assertion id="insert_values_with_columns_with_plain_with_id_for_parameters" db-types="MySQL">
         <input sql="INSERT INTO t_account_bak(account_id, password, amount, status) VALUES (?, ?, ?, ?), (2, 'bbb', 2000, 'OK'), (?, ?, ?, ?), (4, 'ddd', 4000, 'OK')" parameters="1, aaa, 1000, OK, 3, ccc, 3000, OK" />