You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by ca...@apache.org on 2022/12/27 10:40:02 UTC

[doris] branch master updated: [enhancement](rewrite) add OrToIn rule and fix ExtractCommonFactorsRule apply problems (#12872)

This is an automated email from the ASF dual-hosted git repository.

caiconghui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 0550dfaeb2 [enhancement](rewrite) add OrToIn rule and fix ExtractCommonFactorsRule apply problems (#12872)
0550dfaeb2 is described below

commit 0550dfaeb2a419c8f5730f3c2e7c84a1fe67f2c4
Author: Henry2SS <45...@users.noreply.github.com>
AuthorDate: Tue Dec 27 18:39:53 2022 +0800

    [enhancement](rewrite) add OrToIn rule and fix ExtractCommonFactorsRule apply problems (#12872)
    
    Co-authored-by: wuhangze <wu...@jd.com>
---
 .../doris/rewrite/ExtractCommonFactorsRule.java    | 101 ++++++++++++++++++++-
 .../doris/analysis/ListPartitionPrunerTest.java    |   6 +-
 .../doris/analysis/RangePartitionPruneTest.java    |  14 +--
 .../org/apache/doris/planner/QueryPlanTest.java    |  21 +++++
 .../ExtractCommonFactorsRuleFunctionTest.java      |   1 -
 .../data/performance_p0/redundant_conjuncts.out    |   2 +-
 6 files changed, 129 insertions(+), 16 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
index 6ff72d858b..5a3bc34c8c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
@@ -25,6 +25,7 @@ import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.InPredicate;
 import org.apache.doris.analysis.LiteralExpr;
 import org.apache.doris.analysis.SlotRef;
+import org.apache.doris.analysis.TableName;
 import org.apache.doris.common.AnalysisException;
 import org.apache.doris.planner.PlanNode;
 import org.apache.doris.rewrite.ExprRewriter.ClauseType;
@@ -68,6 +69,7 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
 
     @Override
     public Expr apply(Expr expr, Analyzer analyzer, ExprRewriter.ClauseType clauseType) throws AnalysisException {
+        Expr resultExpr = null;
         if (expr == null) {
             return null;
         } else if (expr instanceof CompoundPredicate
@@ -77,12 +79,19 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
                 return rewrittenExpr;
             }
         } else {
-            for (int i = 0; i < expr.getChildren().size(); i++) {
+            if (!(expr instanceof CompoundPredicate)) {
+                return expr;
+            }
+
+            resultExpr = expr.clone();
+
+            for (int i = 0; i < resultExpr.getChildren().size(); i++) {
                 Expr rewrittenExpr = apply(expr.getChild(i), analyzer, clauseType);
                 if (rewrittenExpr != null) {
-                    expr.setChild(i, rewrittenExpr);
+                    resultExpr.setChild(i, rewrittenExpr);
                 }
             }
+            return resultExpr;
         }
         return expr;
     }
@@ -179,10 +188,10 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
         if (CollectionUtils.isNotEmpty(commonFactorList)) {
             result = new CompoundPredicate(CompoundPredicate.Operator.AND,
                     makeCompound(commonFactorList, CompoundPredicate.Operator.AND),
-                    makeCompound(remainingOrClause, CompoundPredicate.Operator.OR));
+                    makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR));
             result.setPrintSqlInParens(true);
         } else {
-            result = makeCompound(remainingOrClause, CompoundPredicate.Operator.OR);
+            result = makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR);
         }
         if (LOG.isDebugEnabled()) {
             LOG.debug("equal ors: " + result.toSql());
@@ -399,6 +408,11 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
 
     /**
      * Rebuild CompoundPredicate, [a, e, f] AND => a and e and f
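+     * The OR-to-IN rewrite is NOT done here; makeCompoundRemaining delegates it to rewriteOrToIn:
+     *     OR [a, b, c]  ==>  col IN (values of a, values of b, values of c)
+     *     when a, b and c are each an EQ BinaryPredicate or an InPredicate whose first child is
+     *     a SlotRef on the same column col and whose second child is a LiteralExpr.
+     *     Otherwise makeCompoundRemaining chains the operands into a CompoundPredicate.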
      */
     private Expr makeCompound(List<Expr> exprs, CompoundPredicate.Operator op) {
         if (CollectionUtils.isEmpty(exprs)) {
@@ -415,6 +429,85 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
         return result;
     }
 
+    /**
+     * Rebuild the remaining clause from its operands. For OR, first try rewriteOrToIn to
+     * collapse the operands into one IN predicate; if that is not possible (or for AND),
+     * chain them left-deep as ((e0 op e1) op e2) ... and enable SQL parentheses on the result.
+     * Returns null for an empty list and the operand itself for a single-element list.
+     */
+    private Expr makeCompoundRemaining(List<Expr> exprs, CompoundPredicate.Operator op) {
+        if (CollectionUtils.isEmpty(exprs)) {
+            return null;
+        }
+        if (exprs.size() == 1) {
+            return exprs.get(0);
+        }
+        if (op == CompoundPredicate.Operator.OR) {
+            Expr merged = rewriteOrToIn(exprs);
+            if (merged != null) {
+                return merged;
+            }
+        }
+        CompoundPredicate chained = new CompoundPredicate(op, exprs.get(0), exprs.get(1));
+        for (Expr operand : exprs.subList(2, exprs.size())) {
+            chained = new CompoundPredicate(op, chained.clone(), operand);
+        }
+        chained.setPrintSqlInParens(true);
+        return chained;
+    }
+
+    /**
+     * Try to merge the operands of an OR into a single IN predicate on one column, e.g.
+     * [k1 = 1, k1 = 2, k1 IN (3, 4)] ==> k1 IN (1, 2, 3, 4).
+     * Every operand must be an EQ BinaryPredicate or a positive (non-NOT) InPredicate whose
+     * first child is a SlotRef and whose second child is a LiteralExpr, and all operands must
+     * reference the same (table-qualified when known) column.
+     * NOT IN operands are rejected: folding "k1 NOT IN (2, 3)" into a positive IN list
+     * would invert its meaning and return wrong rows.
+     *
+     * @param exprs operands of the remaining OR clause
+     * @return the merged InPredicate, or null if the operands cannot be rewritten
+     */
+    private Expr rewriteOrToIn(List<Expr> exprs) {
+        // column identities referenced by the operands; a rewrite needs exactly one
+        Set<String> slotSet = new LinkedHashSet<>();
+        for (Expr predicate : exprs) {
+            if (predicate instanceof InPredicate) {
+                if (((InPredicate) predicate).isNotIn()) {
+                    return null;
+                }
+            } else if (!(predicate instanceof BinaryPredicate)
+                    || ((BinaryPredicate) predicate).getOp() != BinaryPredicate.Operator.EQ) {
+                return null;
+            }
+            if (!(predicate.getChild(0) instanceof SlotRef)
+                    || !(predicate.getChild(1) instanceof LiteralExpr)) {
+                return null;
+            }
+            SlotRef slotRef = (SlotRef) predicate.getChild(0);
+            TableName tableName = slotRef.getTableName();
+            // qualify with the table name when available so same-named columns of different
+            // tables are not merged together
+            slotSet.add(tableName != null
+                    ? tableName.toString() + "." + slotRef.getColumnName()
+                    : slotRef.getColumnName());
+        }
+        if (slotSet.size() != 1) {
+            return null;
+        }
+
+        // seed the IN list with the first operand's values, then append the others' values;
+        // for "col = v" the value list is [v], for "col IN (v1, v2)" it is [v1, v2]
+        Expr first = exprs.get(0);
+        List<Expr> values = first.getChildren();
+        InPredicate inPredicate = new InPredicate(first.getChild(0), values.subList(1, values.size()), false);
+        for (int i = 1; i < exprs.size(); i++) {
+            values = exprs.get(i).getChildren();
+            inPredicate.addChildren(values.subList(1, values.size()));
+        }
+        return inPredicate;
+    }
+
     /**
      * Convert RangeSet to Compound Predicate
      * @param slotRef: <k1>
diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/ListPartitionPrunerTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/ListPartitionPrunerTest.java
index d7bae60c8f..a377ef0d68 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/analysis/ListPartitionPrunerTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/ListPartitionPrunerTest.java
@@ -109,9 +109,9 @@ public class ListPartitionPrunerTest extends PartitionPruneTestBase {
         addCase("select * from test.t4 where k1 >= 2 and k2 = \"shanghai\";", "partitions=2/3", "partitions=1/3");
 
         // Disjunctive predicates
-        addCase("select * from test.t2 where k1=1 or k1=4", "partitions=3/3", "partitions=2/3");
-        addCase("select * from test.t4 where k1=1 or k1=3", "partitions=3/3", "partitions=2/3");
-        addCase("select * from test.t4 where k2=\"tianjin\" or k2=\"shanghai\"", "partitions=3/3", "partitions=2/3");
+        addCase("select * from test.t2 where k1=1 or k1=4", "partitions=2/3", "partitions=2/3");
+        addCase("select * from test.t4 where k1=1 or k1=3", "partitions=2/3", "partitions=2/3");
+        addCase("select * from test.t4 where k2=\"tianjin\" or k2=\"shanghai\"", "partitions=2/3", "partitions=2/3");
         addCase("select * from test.t4 where k1 > 1 or k2 < \"shanghai\"", "partitions=3/3", "partitions=3/3");
     }
 
diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/RangePartitionPruneTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/RangePartitionPruneTest.java
index 8c60f543f4..f1bbc3ba91 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/analysis/RangePartitionPruneTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/RangePartitionPruneTest.java
@@ -171,19 +171,19 @@ public class RangePartitionPruneTest extends PartitionPruneTestBase {
         addCase("select * from test.multi_not_null where k1 > 10 and k1 is null", "partitions=0/2", "partitions=0/2");
         // others predicates combination
         addCase("select * from test.t2 where k1 > 10 and k2 < 4", "partitions=6/9", "partitions=6/9");
-        addCase("select * from test.t2 where k1 >10 and k1 < 10 and (k1=11 or k1=12)", "partitions=0/9", "partitions=0/9");
+        addCase("select * from test.t2 where k1 >10 and k1 < 10 and (k1=11 or k1=12)", "partitions=1/9", "partitions=0/9");
         addCase("select * from test.t2 where k1 > 20 and k1 < 7 and k1 = 10", "partitions=0/9", "partitions=0/9");
 
         // 4. Disjunctive predicates
-        addCase("select * from test.t2 where k1=10 or k1=23", "partitions=9/9", "partitions=3/9");
-        addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=5)", "partitions=9/9", "partitions=1/9");
-        addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=11)", "partitions=9/9", "partitions=2/9");
-        addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=3 or k2=4 or k2=11)", "partitions=9/9", "partitions=3/9");
-        addCase("select * from test.t1 where dt=20211123 or dt=20211124", "partitions=8/8", "partitions=2/8");
+        addCase("select * from test.t2 where k1=10 or k1=23", "partitions=3/9", "partitions=3/9");
+        addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=5)", "partitions=1/9", "partitions=1/9");
+        addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=11)", "partitions=2/9", "partitions=2/9");
+        addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=3 or k2=4 or k2=11)", "partitions=3/9", "partitions=3/9");
+        addCase("select * from test.t1 where dt=20211123 or dt=20211124", "partitions=2/8", "partitions=2/8");
         addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3))", "partitions=8/8", "partitions=2/8");
         // TODO: predicates are "PREDICATES: ((`dt` = 20211123 AND `k1` = 1) OR (`dt` = 20211125 AND `k1` = 3)), `k2` > ",
         // maybe something goes wrong with ExtractCommonFactorsRule.
-        addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3)) and k2>0", "partitions=8/8", "partitions=8/8");
+        addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3)) and k2>0", "partitions=8/8", "partitions=2/8");
         addCase("select * from test.t2 where k1 > 10 or k2 < 1", "partitions=9/9", "partitions=9/9");
         // add some cases for CompoundPredicate
         addCase("select * from test.t1 where (dt >= 20211121 and dt <= 20211122) or (dt >= 20211123 and dt <= 20211125)",
diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
index 56b68df235..d196d19195 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
@@ -2219,4 +2219,25 @@ public class QueryPlanTest extends TestWithFeService {
         String explainString = getSQLPlanOrErrorMsg(queryBaseTableStr);
         Assert.assertTrue(explainString.contains("PREAGGREGATION: ON"));
     }
+
+    @Test
+    public void testRewriteOrToIn() throws Exception {
+        // Equality/IN disjuncts on the same column are expected to appear as a single IN predicate.
+        connectContext.setDatabase("default_cluster:test");
+        String query = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 4)";
+        String plan = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + query);
+        Assert.assertTrue(plan.contains("PREDICATES: `query_time` IN (1, 2, 3, 4)"));
+        query = "SELECT * from test1 where (query_time = 1 or query_time = 2) and query_time in (3, 4)";
+        plan = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + query);
+        Assert.assertTrue(plan.contains("PREDICATES: `query_time` IN (1, 2), `query_time` IN (3, 4)"));
+        // A disjunct on another column keeps the OR in place.
+        query = "SELECT * from test1 where (query_time = 1 or query_time = 2 or scan_bytes = 2) and scan_bytes in (2, 3)";
+        plan = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + query);
+        Assert.assertTrue(plan.contains("PREDICATES: (`query_time` IN (1, 2) OR `scan_bytes` = 2), `scan_bytes` IN (2, 3)"));
+        query = "SELECT * from test1 where (query_time = 1 or query_time = 2) and (scan_bytes = 2 or scan_bytes = 3)";
+        plan = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + query);
+        // Both OR groups are rewritten; the order of the scan_bytes values is not fixed.
+        Assert.assertTrue(plan.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (2, 3)")
+                || plan.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (3, 2)"));
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java
index bb6807abee..588df5c6a6 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java
@@ -83,7 +83,6 @@ public class ExtractCommonFactorsRuleFunctionTest {
         Assert.assertEquals(1, StringUtils.countMatches(planString, "`tb1`.`k1` = `tb2`.`k1`"));
     }
 
-
     @Test
     public void testWideCommonFactorsWithOrPredicate() throws Exception {
         String query = "select * from tb1 where tb1.k1 > 1000 or tb1.k1 < 200 or tb1.k1 = 300";
diff --git a/regression-test/data/performance_p0/redundant_conjuncts.out b/regression-test/data/performance_p0/redundant_conjuncts.out
index 98178f31aa..7dbabccf37 100644
--- a/regression-test/data/performance_p0/redundant_conjuncts.out
+++ b/regression-test/data/performance_p0/redundant_conjuncts.out
@@ -23,7 +23,7 @@ PLAN FRAGMENT 0
 
   0:VOlapScanNode
      TABLE: default_cluster:regression_test_performance_p0.redundant_conjuncts(redundant_conjuncts), PREAGGREGATION: OFF. Reason: No AggregateInfo
-     PREDICATES: (`k1` = 1 OR `k1` = 2)
+     PREDICATES: `k1` IN (1, 2)
      partitions=0/1, tablets=0/0, tabletList=
      cardinality=0, avgRowSize=8.0, numNodes=1
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org