You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by pa...@apache.org on 2021/09/28 10:45:55 UTC

[shardingsphere] branch master updated: optimize encrypt code style (#12793)

This is an automated email from the ASF dual-hosted git repository.

panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new e2a091d  optimize encrypt code style (#12793)
e2a091d is described below

commit e2a091d4c4a6389fdbcd8b7fd27cfebe18898a88
Author: Zhengqiang Duan <du...@apache.org>
AuthorDate: Tue Sep 28 18:45:06 2021 +0800

    optimize encrypt code style (#12793)
---
 .../impl/EncryptPredicateColumnTokenGenerator.java | 40 +++++++++-------------
 .../impl/EncryptProjectionTokenGenerator.java      | 25 ++++++--------
 .../common/segment/dml/predicate/WhereSegment.java |  1 -
 .../parser/sql/common/util/WhereExtractUtil.java   | 30 +++++++---------
 4 files changed, 40 insertions(+), 56 deletions(-)

diff --git a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptPredicateColumnTokenGenerator.java b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptPredicateColumnTokenGenerator.java
index d6ed536..528ef12 100644
--- a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptPredicateColumnTokenGenerator.java
+++ b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptPredicateColumnTokenGenerator.java
@@ -59,33 +59,16 @@ public final class EncryptPredicateColumnTokenGenerator extends BaseEncryptSQLTo
     
     @Override
     protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext sqlStatementContext) {
-        return isGenerateSQLTokenForEncryptOnWhereAvailable(sqlStatementContext) || isGenerateSQLTokenForEncryptOnJoinSegments(sqlStatementContext);
-    }
-    
-    private boolean isGenerateSQLTokenForEncryptOnWhereAvailable(final SQLStatementContext sqlStatementContext) {
-        return sqlStatementContext instanceof WhereAvailable && ((WhereAvailable) sqlStatementContext).getWhere().isPresent();
-    }
-    
-    private boolean isGenerateSQLTokenForEncryptOnJoinSegments(final SQLStatementContext sqlStatementContext) {
-        return sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsJoinQuery();
+        return (sqlStatementContext instanceof WhereAvailable && ((WhereAvailable) sqlStatementContext).getWhere().isPresent())
+            || (sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsJoinQuery());
     }
     
     @Override
     public Collection<SubstitutableColumnNameToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
-        Collection<SubstitutableColumnNameToken> result = new LinkedHashSet<>();
-        Collection<AndPredicate> andPredicates = new LinkedHashSet<>();
-        if (isGenerateSQLTokenForEncryptOnWhereAvailable(sqlStatementContext)) {
-            ExpressionSegment expression = ((WhereAvailable) sqlStatementContext).getWhere().get().getExpr();
-            andPredicates.addAll(ExpressionExtractUtil.getAndPredicates(expression));
-        }
-        Collection<WhereSegment> whereSegments = Collections.emptyList();
-        if (sqlStatementContext instanceof SelectStatementContext) {
-            whereSegments = WhereExtractUtil.getJoinWhereSegments((SelectStatement) sqlStatementContext.getSqlStatement());
-            andPredicates.addAll(whereSegments.stream().map(each -> ExpressionExtractUtil.getAndPredicates(each.getExpr())).flatMap(Collection::stream).collect(Collectors.toList()));
-        }
+        Collection<WhereSegment> whereSegments = getWhereSegments(sqlStatementContext);
+        Collection<AndPredicate> andPredicates = whereSegments.stream().flatMap(each -> ExpressionExtractUtil.getAndPredicates(each.getExpr()).stream()).collect(Collectors.toList());
         Map<String, String> columnTableNames = getColumnTableNames(sqlStatementContext, andPredicates, whereSegments);
-        result.addAll(andPredicates.stream().map(each -> generateSQLTokens(each.getPredicates(), columnTableNames)).flatMap(Collection::stream).collect(Collectors.toList()));
-        return result;
+        return andPredicates.stream().flatMap(each -> generateSQLTokens(each.getPredicates(), columnTableNames).stream()).collect(Collectors.toCollection(LinkedHashSet::new));
     }
     
     private Collection<SubstitutableColumnNameToken> generateSQLTokens(final Collection<ExpressionSegment> predicates, final Map<String, String> columnTableNames) {
@@ -115,11 +98,22 @@ public final class EncryptPredicateColumnTokenGenerator extends BaseEncryptSQLTo
         return result;
     }
     
+    private Collection<WhereSegment> getWhereSegments(final SQLStatementContext<?> sqlStatementContext) {
+        Collection<WhereSegment> result = new LinkedList<>();
+        if (sqlStatementContext instanceof WhereAvailable && ((WhereAvailable) sqlStatementContext).getWhere().isPresent()) {
+            result.add(((WhereAvailable) sqlStatementContext).getWhere().get());
+        }
+        if (sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsJoinQuery()) {
+            result.addAll(WhereExtractUtil.getJoinWhereSegments((SelectStatement) sqlStatementContext.getSqlStatement()));
+        }
+        return result;
+    }
+    
     private Map<String, String> getColumnTableNames(final SQLStatementContext<?> sqlStatementContext, final Collection<AndPredicate> andPredicates, 
             final Collection<WhereSegment> whereSegments) {
         Collection<ColumnSegment> columns = andPredicates.stream().flatMap(each -> each.getPredicates().stream())
                 .flatMap(each -> ColumnExtractor.extract(each).stream()).filter(Objects::nonNull).collect(Collectors.toList());
-        columns.addAll(whereSegments.stream().map(each -> ColumnExtractor.extract(each.getExpr())).flatMap(Collection::stream).collect(Collectors.toList()));
+        columns.addAll(whereSegments.stream().flatMap(each -> ColumnExtractor.extract(each.getExpr()).stream()).collect(Collectors.toList()));
         return sqlStatementContext.getTablesContext().findTableName(columns, schema);
     }
     
diff --git a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
index 71bfc17..1144dac 100644
--- a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
+++ b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptProjectionTokenGenerator.java
@@ -17,13 +17,7 @@
 
 package org.apache.shardingsphere.encrypt.rewrite.token.generator.impl;
 
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Optional;
-
+import lombok.Setter;
 import org.apache.shardingsphere.encrypt.rewrite.aware.QueryWithCipherColumnAware;
 import org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
 import org.apache.shardingsphere.encrypt.rule.EncryptTable;
@@ -44,7 +38,12 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.Projecti
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
 
-import lombok.Setter;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
 
 /**
  * Projection token generator for encrypt.
@@ -65,13 +64,9 @@ public final class EncryptProjectionTokenGenerator extends BaseEncryptSQLTokenGe
     @Override
     public Collection<SubstitutableColumnNameToken> generateSQLTokens(final SelectStatementContext selectStatementContext) {
         ProjectionsSegment projectionsSegment = selectStatementContext.getSqlStatement().getProjections();
-        Collection<SubstitutableColumnNameToken> result = new HashSet<>();
-        Collection<String> tableNames = selectStatementContext.getTablesContext().getTableNames();
-        for (String each : tableNames) {
-            Optional<EncryptTable> encryptTable = getEncryptRule().findEncryptTable(each);
-            if (encryptTable.isPresent()) {
-                result.addAll(generateSQLTokens(projectionsSegment, each, selectStatementContext, encryptTable.get()));
-            }
+        Collection<SubstitutableColumnNameToken> result = new LinkedHashSet<>();
+        for (String each : selectStatementContext.getTablesContext().getTableNames()) {
+            getEncryptRule().findEncryptTable(each).map(optional -> generateSQLTokens(projectionsSegment, each, selectStatementContext, optional)).ifPresent(result::addAll);
         }
         return result;
     }
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/predicate/WhereSegment.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/predicate/WhereSegment.java
index 9ec89c0..299d152 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/predicate/WhereSegment.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/predicate/WhereSegment.java
@@ -20,7 +20,6 @@ package org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate;
 import lombok.Getter;
 import lombok.RequiredArgsConstructor;
 import lombok.Setter;
-
 import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
 
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtil.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtil.java
index 45cc019..a5945a8 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtil.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtil.java
@@ -17,10 +17,8 @@
 
 package org.apache.shardingsphere.sql.parser.sql.common.util;
 
-import java.util.Collection;
-import java.util.Collections;
-import java.util.LinkedList;
-
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
@@ -28,8 +26,9 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.Joi
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 
-import lombok.AccessLevel;
-import lombok.NoArgsConstructor;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
 
 /**
  * Where extract utility class.
@@ -47,24 +46,21 @@ public final class WhereExtractUtil {
         if (null == selectStatement.getFrom()) {
             return Collections.emptyList();
         }
-        TableSegment tableSegment = selectStatement.getFrom();
+        return getJoinWhereSegments(selectStatement.getFrom());
+    }
+    
+    private static Collection<WhereSegment> getJoinWhereSegments(final TableSegment tableSegment) {
         if (!(tableSegment instanceof JoinTableSegment) || null == ((JoinTableSegment) tableSegment).getCondition()) {
             return Collections.emptyList();
         }
+        JoinTableSegment joinTableSegment = (JoinTableSegment) tableSegment;
         Collection<WhereSegment> result = new LinkedList<>();
-        processJoinTableSegment(tableSegment, result);
+        result.add(generateWhereSegment(joinTableSegment));
+        result.addAll(getJoinWhereSegments(joinTableSegment.getLeft()));
+        result.addAll(getJoinWhereSegments(joinTableSegment.getRight()));
         return result;
     }
     
-    private static void processJoinTableSegment(final TableSegment tableSegment, final Collection<WhereSegment> whereSegments) {
-        if (null == tableSegment || !(tableSegment instanceof JoinTableSegment) || null == ((JoinTableSegment) tableSegment).getCondition()) {
-            return;
-        }
-        JoinTableSegment joinTableSegment = (JoinTableSegment) tableSegment;
-        whereSegments.add(generateWhereSegment(joinTableSegment));
-        processJoinTableSegment(joinTableSegment.getLeft(), whereSegments);
-    }
-    
     private static WhereSegment generateWhereSegment(final JoinTableSegment joinTableSegment) {
         ExpressionSegment expressionSegment = joinTableSegment.getCondition();
         return new WhereSegment(expressionSegment.getStartIndex(), expressionSegment.getStopIndex(), expressionSegment);