You are viewing a plain-text version of this content; the canonical link to it (a hyperlink in the original page) is not preserved in this copy.
Posted to notifications@shardingsphere.apache.org by du...@apache.org on 2023/03/13 03:45:26 UTC

[shardingsphere] branch master updated: Encrypt like supports concat function (#24502)

This is an automated email from the ASF dual-hosted git repository.

duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new cb140a26bff Encrypt like supports concat function (#24502)
cb140a26bff is described below

commit cb140a26bff185111c25bd528687824cd56cf8ac
Author: gxxiong <xi...@foxmail.com>
AuthorDate: Mon Mar 13 11:45:18 2023 +0800

    Encrypt like supports concat function (#24502)
    
    * Encrypt like supports concat function
    
    * add e2e test
---
 .../condition/impl/EncryptBinaryCondition.java     | 17 +++++
 .../EncryptPredicateRightValueTokenGenerator.java  | 10 +++
 .../EncryptPredicateFunctionRightValueToken.java   | 82 ++++++++++++++++++++++
 ...ncryptPredicateFunctionRightValueTokenTest.java | 43 ++++++++++++
 .../cases/dql/dql-integration-select-aggregate.xml |  4 ++
 .../query-with-cipher/dml/select/select-where.xml  |  5 ++
 .../query-with-plain/dml/select/select-where.xml   |  5 ++
 7 files changed, 166 insertions(+)

diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/impl/EncryptBinaryCondition.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/impl/EncryptBinaryCondition.java
index 6f535243a1e..3c78e95cdd9 100644
--- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/impl/EncryptBinaryCondition.java
+++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/impl/EncryptBinaryCondition.java
@@ -22,10 +22,13 @@ import lombok.Getter;
 import lombok.ToString;
 import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
 
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -49,6 +52,8 @@ public final class EncryptBinaryCondition implements EncryptCondition {
     
     private final int stopIndex;
     
+    private final ExpressionSegment expressionSegment;
+    
     private final Map<Integer, Integer> positionIndexMap = new LinkedHashMap<>();
     
     private final Map<Integer, Object> positionValueMap = new LinkedHashMap<>();
@@ -59,6 +64,7 @@ public final class EncryptBinaryCondition implements EncryptCondition {
         this.operator = operator;
         this.startIndex = startIndex;
         this.stopIndex = stopIndex;
+        this.expressionSegment = expressionSegment;
         putPositionMap(expressionSegment);
     }
     
@@ -67,6 +73,21 @@ public final class EncryptBinaryCondition implements EncryptCondition {
             positionIndexMap.put(0, ((ParameterMarkerExpressionSegment) expressionSegment).getParameterMarkerIndex());
         } else if (expressionSegment instanceof LiteralExpressionSegment) {
             positionValueMap.put(0, ((LiteralExpressionSegment) expressionSegment).getLiterals());
+        } else if (expressionSegment instanceof FunctionSegment) {
+            // Record each argument by its position in the call: literals as values, parameter markers by their
+            // parameter index, so "?" arguments are not lost. Other argument kinds (e.g. columns) are not recorded.
+            Collection<ExpressionSegment> parameters = ((FunctionSegment) expressionSegment).getParameters();
+            Iterator<ExpressionSegment> iterator = parameters.iterator();
+            int position = 0;
+            while (iterator.hasNext()) {
+                ExpressionSegment next = iterator.next();
+                if (next instanceof ParameterMarkerExpressionSegment) {
+                    positionIndexMap.put(position, ((ParameterMarkerExpressionSegment) next).getParameterMarkerIndex());
+                } else if (next instanceof LiteralExpressionSegment) {
+                    positionValueMap.put(position, ((LiteralExpressionSegment) next).getLiterals());
+                }
+                position++;
+            }
         }
     }
     
diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptPredicateRightValueTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptPredicateRightValueTokenGenerator.java
index 967aa207fd4..ce97178e824 100644
--- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptPredicateRightValueTokenGenerator.java
+++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptPredicateRightValueTokenGenerator.java
@@ -26,6 +26,7 @@ import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
 import org.apache.shardingsphere.encrypt.rewrite.condition.impl.EncryptBinaryCondition;
 import org.apache.shardingsphere.encrypt.rewrite.condition.impl.EncryptInCondition;
 import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptPredicateEqualRightValueToken;
+import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptPredicateFunctionRightValueToken;
 import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptPredicateInRightValueToken;
 import org.apache.shardingsphere.encrypt.rule.EncryptRule;
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
@@ -34,6 +35,7 @@ import org.apache.shardingsphere.infra.database.type.DatabaseTypeEngine;
 import org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
 import org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.ParametersAware;
 import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
 
 import java.util.Collection;
 import java.util.HashMap;
@@ -90,6 +92,10 @@ public final class EncryptPredicateRightValueTokenGenerator
         int stopIndex = encryptCondition.getStopIndex();
         Map<Integer, Object> indexValues = getPositionValues(encryptCondition.getPositionValueMap().keySet(), getEncryptedValues(schemaName, encryptCondition, originalValues));
         Collection<Integer> parameterMarkerIndexes = encryptCondition.getPositionIndexMap().keySet();
+        if (encryptCondition instanceof EncryptBinaryCondition && ((EncryptBinaryCondition) encryptCondition).getExpressionSegment() instanceof FunctionSegment) {
+            return new EncryptPredicateFunctionRightValueToken(startIndex, stopIndex,
+                    ((FunctionSegment) ((EncryptBinaryCondition) encryptCondition).getExpressionSegment()).getFunctionName(), indexValues, parameterMarkerIndexes);
+        }
         return encryptCondition instanceof EncryptInCondition
                 ? new EncryptPredicateInRightValueToken(startIndex, stopIndex, indexValues, parameterMarkerIndexes)
                 : new EncryptPredicateEqualRightValueToken(startIndex, stopIndex, indexValues, parameterMarkerIndexes);
@@ -120,6 +126,10 @@ public final class EncryptPredicateRightValueTokenGenerator
             indexValues.putAll(getPositionValues(encryptCondition.getPositionValueMap().keySet(), getEncryptedValues(schemaName, encryptCondition, originalValues)));
         }
         Collection<Integer> parameterMarkerIndexes = encryptCondition.getPositionIndexMap().keySet();
+        if (encryptCondition instanceof EncryptBinaryCondition && ((EncryptBinaryCondition) encryptCondition).getExpressionSegment() instanceof FunctionSegment) {
+            return new EncryptPredicateFunctionRightValueToken(startIndex, stopIndex,
+                    ((FunctionSegment) ((EncryptBinaryCondition) encryptCondition).getExpressionSegment()).getFunctionName(), indexValues, parameterMarkerIndexes);
+        }
         return encryptCondition instanceof EncryptInCondition
                 ? new EncryptPredicateInRightValueToken(startIndex, stopIndex, indexValues, parameterMarkerIndexes)
                 : new EncryptPredicateEqualRightValueToken(startIndex, stopIndex, indexValues, parameterMarkerIndexes);
diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateFunctionRightValueToken.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateFunctionRightValueToken.java
new file mode 100644
index 00000000000..5ab6e115a86
--- /dev/null
+++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateFunctionRightValueToken.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.encrypt.rewrite.token.pojo;
+
+import lombok.Getter;
+import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
+import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.Substitutable;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Predicate in right value token for encrypt.
+ */
+public final class EncryptPredicateFunctionRightValueToken extends SQLToken implements Substitutable {
+    
+    @Getter
+    private final int stopIndex;
+    
+    private final String functionName;
+    
+    private final Map<Integer, Object> indexValues;
+    
+    private final Collection<Integer> paramMarkerIndexes;
+    
+    public EncryptPredicateFunctionRightValueToken(final int startIndex, final int stopIndex, final String functionName,
+                                                   final Map<Integer, Object> indexValues, final Collection<Integer> paramMarkerIndexes) {
+        super(startIndex);
+        this.stopIndex = stopIndex;
+        this.functionName = functionName;
+        this.indexValues = indexValues;
+        this.paramMarkerIndexes = paramMarkerIndexes;
+    }
+    
+    @Override
+    public String toString() {
+        StringBuilder result = new StringBuilder();
+        result.append(functionName).append(" (");
+        for (int i = 0; i < indexValues.size() + paramMarkerIndexes.size(); i++) {
+            if (paramMarkerIndexes.contains(i)) {
+                result.append("?");
+            } else {
+                if (indexValues.get(i) instanceof String) {
+                    result.append("'").append(indexValues.get(i)).append("'");
+                } else {
+                    result.append(indexValues.get(i));
+                }
+            }
+            result.append(", ");
+        }
+        result.delete(result.length() - 2, result.length()).append(")");
+        return result.toString();
+    }
+    
+    @Override
+    public boolean equals(final Object obj) {
+        return obj instanceof EncryptPredicateFunctionRightValueToken && ((EncryptPredicateFunctionRightValueToken) obj).getStartIndex() == getStartIndex()
+                && ((EncryptPredicateFunctionRightValueToken) obj).getStopIndex() == stopIndex && ((EncryptPredicateFunctionRightValueToken) obj).indexValues.equals(indexValues)
+                && ((EncryptPredicateFunctionRightValueToken) obj).paramMarkerIndexes.equals(paramMarkerIndexes);
+    }
+    
+    @Override
+    public int hashCode() {
+        return Objects.hash(getStartIndex(), stopIndex, indexValues, paramMarkerIndexes);
+    }
+}
diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/pojo/EncryptPredicateFunctionRightValueTokenTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/pojo/EncryptPredicateFunctionRightValueTokenTest.java
new file mode 100644
index 00000000000..c2a0658c0a6
--- /dev/null
+++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/pojo/EncryptPredicateFunctionRightValueTokenTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.encrypt.rewrite.pojo;
+
+import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptPredicateFunctionRightValueToken;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+public final class EncryptPredicateFunctionRightValueTokenTest {
+    
+    @Test
+    public void assertToStringWithoutPlaceholderWithoutTableOwnerWithFunction() {
+        Map<Integer, Object> indexValues = new LinkedHashMap<>();
+        indexValues.put(0, "%");
+        indexValues.put(1, "abc");
+        indexValues.put(2, "%");
+        FunctionSegment functionSegment = new FunctionSegment(0, 0, "CONCAT", "('%','abc','%')");
+        EncryptPredicateFunctionRightValueToken actual = new EncryptPredicateFunctionRightValueToken(0, 0, functionSegment.getFunctionName(), indexValues, Collections.emptyList());
+        assertThat(actual.toString(), is("CONCAT ('%', 'abc', '%')"));
+    }
+    
+    @Test
+    public void assertToStringWithPlaceholderWithFunction() {
+        Map<Integer, Object> indexValues = new LinkedHashMap<>();
+        indexValues.put(0, "%");
+        indexValues.put(2, "%");
+        EncryptPredicateFunctionRightValueToken actual = new EncryptPredicateFunctionRightValueToken(0, 0, "CONCAT", indexValues, Collections.singletonList(1));
+        assertThat(actual.toString(), is("CONCAT ('%', ?, '%')"));
+    }
+}
diff --git a/test/e2e/suite/src/test/resources/cases/dql/dql-integration-select-aggregate.xml b/test/e2e/suite/src/test/resources/cases/dql/dql-integration-select-aggregate.xml
index 17e433fce14..a18673c5446 100644
--- a/test/e2e/suite/src/test/resources/cases/dql/dql-integration-select-aggregate.xml
+++ b/test/e2e/suite/src/test/resources/cases/dql/dql-integration-select-aggregate.xml
@@ -133,4 +133,8 @@
     <test-case sql="SELECT MAX(p.price) AS max_price, MIN(p.price) AS min_price, SUM(p.price) AS sum_price, AVG(p.price) AS avg_price, COUNT(1) AS count FROM t_order o INNER JOIN t_order_item i ON o.order_id = i.order_id INNER JOIN t_product p ON i.product_id = p.product_id GROUP BY o.order_id HAVING SUM(p.price) > ? ORDER BY max_price" db-types="MySQL,PostgreSQL,openGauss" scenario-types="db">
         <assertion parameters="10000:int" expected-data-source-name="read_dataset" />
     </test-case>
+    
+    <test-case sql="SELECT * FROM t_merchant WHERE business_code LIKE CONCAT('%','abc','%')" db-types="MySQL,PostgreSQL,openGauss" scenario-types="encrypt">
+        <assertion expected-data-source-name="read_dataset" />
+    </test-case>
 </integration-test-cases>
diff --git a/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-cipher/dml/select/select-where.xml b/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-cipher/dml/select/select-where.xml
index 52b76eb0af8..002d81f3f79 100644
--- a/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-cipher/dml/select/select-where.xml
+++ b/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-cipher/dml/select/select-where.xml
@@ -83,4 +83,9 @@
         <input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_detail AS a WHERE a.account_id = 1 AND a.password = 'aaa' AND a.amount = 1000 AND a.status = 'OK'" />
         <output sql="SELECT a.account_id, a.plain_password AS password, a.plain_amount AS a, a.status AS s FROM t_account_detail AS a WHERE a.account_id = 1 AND a.plain_password = 'aaa' AND a.plain_amount = 1000 AND a.status = 'OK'" />
     </rewrite-assertion>
+    
+    <rewrite-assertion id="select_where_with_cipher_column_like_concat_for_literals" db-types="PostgreSQL,openGauss">
+        <input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_bak AS a WHERE a.account_id = 1 AND a.certificate_number like concat('%','abc','%')" />
+        <output sql="SELECT a.account_id, a.cipher_password AS password, a.cipher_amount AS a, a.status AS s FROM t_account_bak AS a WHERE a.account_id = 1 AND a.like_query_certificate_number like concat ('like_query_%', 'like_query_abc', 'like_query_%')" />
+    </rewrite-assertion>
 </rewrite-assertions>
diff --git a/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-plain/dml/select/select-where.xml b/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-plain/dml/select/select-where.xml
index 6a743747ac6..4849fe534c7 100644
--- a/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-plain/dml/select/select-where.xml
+++ b/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-plain/dml/select/select-where.xml
@@ -84,4 +84,9 @@
         <input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_detail AS a WHERE a.account_id = 1 AND a.password = 'aaa' AND a.password like 'aaa' AND a.amount = 1000 AND a.status = 'OK'" />
         <output sql="SELECT a.account_id, a.cipher_password AS password, a.cipher_amount AS a, a.status AS s FROM t_account_detail AS a WHERE a.account_id = 1 AND a.assisted_query_password = 'assisted_query_aaa' AND a.like_query_password like 'like_query_aaa' AND a.cipher_amount = 'encrypt_1000' AND a.status = 'OK'" />
     </rewrite-assertion>
+    
+    <rewrite-assertion id="select_where_with_plain_column_like_concat_for_literals" db-types="PostgreSQL,openGauss">
+        <input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_bak AS a WHERE a.account_id = 1 AND a.certificate_number like concat('%','abc','%')" />
+        <output sql="SELECT a.account_id, a.plain_password AS password, a.plain_amount AS a, a.status AS s FROM t_account_bak AS a WHERE a.account_id = 1 AND a.plain_certificate_number like concat ('%', 'abc', '%')" />
+    </rewrite-assertion>
 </rewrite-assertions>