You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by du...@apache.org on 2021/11/01 02:25:22 UTC

[shardingsphere] branch master updated: ShardingSphere-JDBC 5.0.0-RC1-SNAPSHOT: throw StringIndexOutOfBoundsException when config the encrypt rule and use the sql join (#13305)

This is an automated email from the ASF dual-hosted git repository.

duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new e775f20  ShardingSphere-JDBC 5.0.0-RC1-SNAPSHOT: throw StringIndexOutOfBoundsException when config the encrypt rule and use the sql join (#13305)
e775f20 is described below

commit e775f20047f0ac5c31a70bc0a7a73d8886f3b68a
Author: liguoping <xd...@163.com>
AuthorDate: Mon Nov 1 10:24:23 2021 +0800

    ShardingSphere-JDBC 5.0.0-RC1-SNAPSHOT: throw StringIndexOutOfBoundsException when config the encrypt rule and use the sql join (#13305)
    
    * only add SQL tokens for the table itself
    
    * check style
    
    * resolve alias, owner and ambiguous-column handling for the encrypt rule
    
    * java doc
    
    * code style
    
    * delete System.out debug output
    
    * move getAllUniqueTables method to TablesContext
    
    * use int instead of atomicInteger
    
    * extract isOwnerExistsMatchTableAlias, isOwnerExistsMatchTableName and isColumnAmbiguous to one method
    
    * add unit tests
    
    * add unit test for matching by table name
    
    * fix checkstyle: add final modifiers
    
    * isColumnUnAmbiguous
    
    * rewrite test
---
 .../impl/EncryptProjectionTokenGenerator.java      |  57 +++++++--
 .../impl/EncryptProjectionTokenGeneratorTest.java  | 142 +++++++++++++++++++++
 .../infra/binder/segment/table/TablesContext.java  |   9 ++
 .../encrypt/case/select_for_query_with_plain.xml   |  10 ++
 4 files changed, 209 insertions(+), 9 deletions(-)

diff --git a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
index 436093c..b43d874 100644
--- a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
+++ b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
@@ -17,13 +17,8 @@
 
 package org.apache.shardingsphere.encrypt.rewrite.token.generator.impl;
 
-import java.util.Collection;
-import java.util.Collections;
-import java.util.LinkedHashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Optional;
-
+import com.google.common.base.Preconditions;
+import lombok.Setter;
 import org.apache.shardingsphere.encrypt.rewrite.aware.QueryWithCipherColumnAware;
 import org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
 import org.apache.shardingsphere.encrypt.rule.EncryptTable;
@@ -45,7 +40,12 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.Projecti
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
 
-import lombok.Setter;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
 
 /**
  * Projection token generator for encrypt.
@@ -90,7 +90,8 @@ public final class EncryptProjectionTokenGenerator extends BaseEncryptSQLTokenGe
         Collection<SubstitutableColumnNameToken> result = new LinkedList<>();
         for (ProjectionSegment each : segment.getProjections()) {
             if (each instanceof ColumnProjectionSegment) {
-                if (encryptTable.getLogicColumns().contains(((ColumnProjectionSegment) each).getColumn().getIdentifier().getValue())) {
+                if (encryptTable.getLogicColumns().contains(((ColumnProjectionSegment) each).getColumn().getIdentifier().getValue()) 
+                        && columnMatchTableAndCheckAmbiguous(selectStatementContext, (ColumnProjectionSegment) each, tableName)) {
                     result.add(generateSQLToken((ColumnProjectionSegment) each, tableName, insertSelect));
                 }
             }
@@ -104,6 +105,44 @@ public final class EncryptProjectionTokenGenerator extends BaseEncryptSQLTokenGe
         return result;
     }
     
+    /**
+     * Decide whether a projected column should be rewritten for the given encrypt table.
+     * An owner-qualified column is accepted when its owner is the alias of a table named {@code tableName},
+     * or is {@code tableName} itself and such a table appears without an alias.
+     * An unqualified column is accepted unless it is ambiguous among the encrypt tables, in which case an exception is raised.
+     */
+    private boolean columnMatchTableAndCheckAmbiguous(final SelectStatementContext selectStatementContext, final ColumnProjectionSegment columnProjectionSegment, final String tableName) {
+        if (isOwnerExistsMatchTableAlias(selectStatementContext, columnProjectionSegment, tableName)) {
+            return true;
+        }
+        if (isOwnerExistsMatchTableName(selectStatementContext, columnProjectionSegment, tableName)) {
+            return true;
+        }
+        return isColumnUnAmbiguous(selectStatementContext, columnProjectionSegment);
+    }
+    
+    /**
+     * Check whether the column's owner is the alias of a table named {@code tableName} in this statement.
+     * Returns false for columns without an owner.
+     */
+    private boolean isOwnerExistsMatchTableAlias(final SelectStatementContext selectStatementContext, final ColumnProjectionSegment columnProjectionSegment, final String tableName) {
+        Optional<OwnerSegment> owner = columnProjectionSegment.getColumn().getOwner();
+        if (!owner.isPresent()) {
+            return false;
+        }
+        String ownerName = owner.get().getIdentifier().getValue();
+        return selectStatementContext.getTablesContext().getAllUniqueTables().stream()
+                .filter(each -> tableName.equals(each.getTableName().getIdentifier().getValue()))
+                .anyMatch(each -> each.getAlias().isPresent() && ownerName.equals(each.getAlias().get()));
+    }
+    
+    /**
+     * Check whether the column is qualified by {@code tableName} itself and a table of that name appears without an alias.
+     * Returns false for columns without an owner.
+     */
+    private boolean isOwnerExistsMatchTableName(final SelectStatementContext selectStatementContext, final ColumnProjectionSegment columnProjectionSegment, final String tableName) {
+        Optional<OwnerSegment> owner = columnProjectionSegment.getColumn().getOwner();
+        // The owner comparison does not depend on which table is inspected, so it is checked once up front.
+        if (!owner.isPresent() || !owner.get().getIdentifier().getValue().equals(tableName)) {
+            return false;
+        }
+        return selectStatementContext.getTablesContext().getAllUniqueTables().stream()
+                .anyMatch(each -> tableName.equals(each.getTableName().getIdentifier().getValue()) && !each.getAlias().isPresent());
+    }
+    
+    /**
+     * Check that an unqualified column is not ambiguous among the encrypt tables referenced by the statement.
+     *
+     * @param selectStatementContext select statement context
+     * @param columnProjectionSegment column projection segment
+     * @return false if the column is owner-qualified, otherwise true
+     * @throws IllegalStateException if more than one encrypt table of the statement declares the column as a logic column
+     */
+    private boolean isColumnUnAmbiguous(final SelectStatementContext selectStatementContext, final ColumnProjectionSegment columnProjectionSegment) {
+        if (columnProjectionSegment.getColumn().getOwner().isPresent()) {
+            return false;
+        }
+        String columnName = columnProjectionSegment.getColumn().getIdentifier().getValue();
+        long matchedEncryptTableCount = selectStatementContext.getTablesContext().getTableNames().stream()
+                .map(each -> getEncryptRule().findEncryptTable(each))
+                .filter(each -> each.isPresent() && each.get().getLogicColumns().contains(columnName))
+                .count();
+        Preconditions.checkState(matchedEncryptTableCount <= 1, "column `%s` is ambiguous in encrypt rules", columnName);
+        return true;
+    }
+    
     private boolean isToGeneratedSQLToken(final ProjectionSegment projectionSegment, final SelectStatementContext selectStatementContext, final String tableName) {
         if (!(projectionSegment instanceof ShorthandProjectionSegment)) {
             return false;
diff --git a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/impl/EncryptProjectionTokenGeneratorTest.java b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/impl/EncryptProjectionTokenGeneratorTest.java
new file mode 100644
index 0000000..a555ea0
--- /dev/null
+++ b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/impl/EncryptProjectionTokenGeneratorTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.encrypt.rewrite.impl;
+
+import org.apache.shardingsphere.encrypt.rewrite.token.generator.impl.EncryptProjectionTokenGenerator;
+import org.apache.shardingsphere.encrypt.rule.EncryptRule;
+import org.apache.shardingsphere.encrypt.rule.EncryptTable;
+import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
+import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public final class EncryptProjectionTokenGeneratorTest {
+    
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+    
+    private EncryptProjectionTokenGenerator encryptProjectionTokenGenerator;
+    
+    @Before
+    public void setup() {
+        encryptProjectionTokenGenerator = new EncryptProjectionTokenGenerator();
+        encryptProjectionTokenGenerator.setEncryptRule(buildEncryptRule());
+    }
+    
+    @Test
+    public void assertOwnerExistsMatchTableAliasGenerateSQLTokens() {
+        ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("mobile"));
+        columnSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a")));
+        SelectStatementContext sqlStatementContext = mockSelectStatementContext(columnSegment, true);
+        Collection<SubstitutableColumnNameToken> tokens = encryptProjectionTokenGenerator.generateSQLTokens(sqlStatementContext);
+        assertThat(tokens.size(), is(1));
+    }
+    
+    @Test
+    public void assertOwnerExistsMatchTableNameGenerateSQLTokens() {
+        ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("mobile"));
+        columnSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("doctor")));
+        SelectStatementContext sqlStatementContext = mockSelectStatementContext(columnSegment, false);
+        Collection<SubstitutableColumnNameToken> tokens = encryptProjectionTokenGenerator.generateSQLTokens(sqlStatementContext);
+        assertThat(tokens.size(), is(1));
+    }
+    
+    @Test
+    public void assertGenerateSQLTokensWithAmbiguousColumn() {
+        // "mobile" is a logic column of both encrypt tables "doctor" and "doctor1", so an unqualified reference must be rejected.
+        expectedException.expect(IllegalStateException.class);
+        expectedException.expectMessage("column `mobile` is ambiguous in encrypt rules");
+        ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("mobile"));
+        encryptProjectionTokenGenerator.generateSQLTokens(mockSelectStatementContext(columnSegment, true));
+    }
+    
+    /**
+     * Mock a select statement context over the tables "doctor" and "doctor1" whose only projection is the given column.
+     */
+    private SelectStatementContext mockSelectStatementContext(final ColumnSegment columnSegment, final boolean hasAlias) {
+        SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+        ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class);
+        when(projectionsSegment.getProjections()).thenReturn(Collections.singletonList(new ColumnProjectionSegment(columnSegment)));
+        when(result.getSqlStatement().getProjections()).thenReturn(projectionsSegment);
+        when(result.getTablesContext().getTableNames()).thenReturn(Arrays.asList("doctor", "doctor1"));
+        // Build the table mocks before stubbing so no mock is created in the middle of an unfinished stubbing.
+        List<SimpleTableSegment> allUniqueTables = buildAllUniqueTables(hasAlias);
+        when(result.getTablesContext().getAllUniqueTables()).thenReturn(allUniqueTables);
+        return result;
+    }
+    
+    private List<SimpleTableSegment> buildAllUniqueTables(final boolean hasAlias) {
+        SimpleTableSegment table1 = mock(SimpleTableSegment.class, RETURNS_DEEP_STUBS);
+        when(table1.getTableName().getIdentifier().getValue()).thenReturn("doctor");
+        SimpleTableSegment table2 = mock(SimpleTableSegment.class, RETURNS_DEEP_STUBS);
+        when(table2.getTableName().getIdentifier().getValue()).thenReturn("doctor1");
+        if (hasAlias) {
+            when(table1.getAlias()).thenReturn(Optional.of("a"));
+            when(table2.getAlias()).thenReturn(Optional.of("b"));
+        }
+        return Arrays.asList(table1, table2);
+    }
+    
+    private EncryptRule buildEncryptRule() {
+        EncryptRule encryptRule = mock(EncryptRule.class);
+        EncryptTable encryptTable1 = mock(EncryptTable.class);
+        EncryptTable encryptTable2 = mock(EncryptTable.class);
+        when(encryptTable1.getLogicColumns()).thenReturn(Collections.singletonList("mobile"));
+        when(encryptTable2.getLogicColumns()).thenReturn(Collections.singletonList("mobile"));
+        when(encryptRule.findPlainColumn("doctor", "mobile")).thenReturn(Optional.of("mobile"));
+        when(encryptRule.findPlainColumn("doctor1", "mobile")).thenReturn(Optional.of("Mobile"));
+        when(encryptRule.findEncryptTable("doctor")).thenReturn(Optional.of(encryptTable1));
+        when(encryptRule.findEncryptTable("doctor1")).thenReturn(Optional.of(encryptTable2));
+        return encryptRule;
+    }
+}
diff --git a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
index 50237de..7c9e088 100644
--- a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
+++ b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
@@ -74,6 +74,15 @@ public final class TablesContext {
     }
     
     /**
+     * Get all unique table segments.
+     *
+     * @return all unique table segments
+     */
+    public Collection<SimpleTableSegment> getAllUniqueTables() {
+        // Return a read-only view: the raw values() view would let callers remove entries from the backing uniqueTables map.
+        return java.util.Collections.unmodifiableCollection(uniqueTables.values());
+    }
+    
+    /**
      * Find table name.
      *
      * @param columns column segment collection
diff --git a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/select_for_query_with_plain.xml b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/select_for_query_with_plain.xml
index 2739ccd..385b3e3 100644
--- a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/select_for_query_with_plain.xml
+++ b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/select_for_query_with_plain.xml
@@ -41,4 +41,14 @@
         <input sql="SELECT a.*, account_id, 1+1 FROM t_account_bak a" />
         <output sql="SELECT `a`.`account_id`, `a`.`cipher_certificate_number` AS certificate_number, `a`.`plain_password` AS password, `a`.`plain_amount` AS amount, `a`.`status`, account_id, 1+1 FROM t_account_bak a" />
     </rewrite-assertion>
+    
+    <rewrite-assertion id="select_join_with_alias">
+        <input sql="SELECT a.password from t_account a, t_account_bak b where a.account_id = b.account_id"/>
+        <output sql="SELECT a.cipher_password AS password from t_account a, t_account_bak b where a.account_id = b.account_id"/>
+    </rewrite-assertion>
+    
+    <rewrite-assertion id="select_join_with_table_name">
+        <input sql="SELECT t_account.password from t_account, t_account_bak where t_account.account_id = t_account_bak.account_id"/>
+        <output sql="SELECT t_account.cipher_password AS password from t_account, t_account_bak where t_account.account_id = t_account_bak.account_id"/>
+    </rewrite-assertion>
 </rewrite-assertions>