You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2023/01/23 14:18:02 UTC

[doris] 04/05: [enhance](planner)convert 'or' into 'in-predicate' (#15737)

This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 6ec00a340caa1f77876472bc5d18e6ca540b8153
Author: minghong <en...@gmail.com>
AuthorDate: Wed Jan 18 12:33:20 2023 +0800

    [enhance](planner)convert 'or' into 'in-predicate' (#15737)
    
    In the previous [PR 12872](https://github.com/apache/doris/pull/12872), we converted multiple equality predicates on the same slot into an `in predicate`. For example, `a = 1 or a = 2` => `a in (1, 2)`
    
    This PR makes 4 changes to the or-to-in conversion:
    1. Fix a bug: `NOT IN` was merged with equality predicates, e.g. `a = 1 or a not in (2, 3)` was wrongly rewritten to `a in (1, 2, 3)`.
    2. Extend this rule to more cases:
      - merge on more than one slot: `a = 1 or a = 2 or b = 3 or b = 4` => `a in (1, 2) or b in (3, 4)`
      - skip not-equal and not-in predicates when merging: `a = 1 or a = 2 or b != 3 or c not in (1, 2)` => `a in (1, 2) or b != 3 or c not in (1, 2)`
    3. Rewrite recursively.
    4. OrToIn is implemented in ExtractCommonFactorsRule, which itself generates new exprs; OrToIn should also be applied to those generated exprs. For example, `(a=1 and b=2) or (a=3 and b=4)` => `(a=1 or a=3) and (b=2 or b=4) and [(a=1 and b=2) or (a=3 and b=4)]` => `a in (1, 3) and b in (2, 4) and [(a=1 and b=2) or (a=3 and b=4)]`
    
    In addition, this PR adds toString() to some Expr subclasses.
---
 .../org/apache/doris/analysis/BinaryPredicate.java |   7 ++
 .../apache/doris/analysis/CompoundPredicate.java   |  26 +++++
 .../org/apache/doris/analysis/LiteralExpr.java     |   5 +
 .../java/org/apache/doris/analysis/SlotRef.java    |  12 ++
 .../doris/rewrite/ExtractCommonFactorsRule.java    | 128 +++++++++++++++------
 .../org/apache/doris/analysis/SelectStmtTest.java  |  26 +++--
 .../org/apache/doris/planner/QueryPlanTest.java    |  57 +++++++--
 .../ExtractCommonFactorsRuleFunctionTest.java      |  12 +-
 .../data/performance_p0/redundant_conjuncts.out    |   2 +-
 9 files changed, 215 insertions(+), 60 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java
index 6a0d1002e4..e1f24c4530 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java
@@ -218,6 +218,13 @@ public class BinaryPredicate extends Predicate implements Writable {
         this.op = op;
     }
 
+    @Override
+    public String toString() {
+        StringBuilder builder = new StringBuilder();
+        builder.append(children.get(0)).append(" ").append(op).append(" ").append(children.get(1));
+        return builder.toString();
+    }
+
     @Override
     public Expr negate() {
         Operator newOp = null;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CompoundPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CompoundPredicate.java
index 7ce22bb732..f1e41ad0b4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CompoundPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CompoundPredicate.java
@@ -213,6 +213,27 @@ public class CompoundPredicate extends Predicate {
         return conjunctivePred;
     }
 
+    /**
+     * Creates a disjunctive predicate from a list of exprs,
+     * reserve the expr order
+     */
+    public static Expr createDisjunctivePredicate(List<Expr> disjunctions) {
+        Expr result = null;
+        for (Expr expr : disjunctions) {
+            if (result == null) {
+                result = expr;
+                continue;
+            }
+            result = new CompoundPredicate(CompoundPredicate.Operator.OR, result, expr);
+        }
+        return result;
+    }
+
+    public static boolean isOr(Expr expr) {
+        return expr instanceof CompoundPredicate
+                && ((CompoundPredicate) expr).getOp() == Operator.OR;
+    }
+
     @Override
     public Expr getResultValue() throws AnalysisException {
         recursiveResetChildrenResult();
@@ -261,4 +282,9 @@ public class CompoundPredicate extends Predicate {
     public void finalizeImplForNereids() throws AnalysisException {
 
     }
+
+    @Override
+    public String toString() {
+        return toSqlImpl();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java
index 0528e33eed..022f1b41d6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java
@@ -262,4 +262,9 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr
     public void finalizeImplForNereids() throws AnalysisException {
 
     }
+
+    @Override
+    public String toString() {
+        return getStringValue();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java
index 2c9db0fdad..0650acfb28 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java
@@ -484,4 +484,16 @@ public class SlotRef extends Expr {
     public void finalizeImplForNereids() throws AnalysisException {
 
     }
+
+    @Override
+    public String toString() {
+        StringBuilder builder = new StringBuilder();
+        if (tblName != null) {
+            builder.append(tblName).append(".");
+        }
+        if (label != null) {
+            builder.append(label);
+        }
+        return builder.toString();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
index 74130d52af..d28fde13ea 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
@@ -43,11 +43,13 @@ import org.apache.logging.log4j.Logger;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * This rule extracts common predicate from multiple disjunctions when it is applied
@@ -113,7 +115,8 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
      * 4. Construct new expr:
      * @return: a and b' and (b or (e and f))
      */
-    private Expr extractCommonFactors(List<List<Expr>> exprs, Analyzer analyzer, ExprRewriter.ClauseType clauseType) {
+    private Expr extractCommonFactors(List<List<Expr>> exprs, Analyzer analyzer, ExprRewriter.ClauseType clauseType)
+            throws AnalysisException {
         if (exprs.size() < 2) {
             return null;
         }
@@ -187,12 +190,19 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
         }
         Expr result = null;
         if (CollectionUtils.isNotEmpty(commonFactorList)) {
+            commonFactorList = commonFactorList.stream().map(expr -> {
+                try {
+                    return apply(expr, analyzer, clauseType);
+                } catch (AnalysisException e) {
+                    throw new RuntimeException(e);
+                }
+            }).collect(Collectors.toList());
             result = new CompoundPredicate(CompoundPredicate.Operator.AND,
                     makeCompound(commonFactorList, CompoundPredicate.Operator.AND),
-                    makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR));
+                    makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR, analyzer, clauseType));
             result.setPrintSqlInParens(true);
         } else {
-            result = makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR);
+            result = makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR, analyzer, clauseType);
         }
         if (LOG.isDebugEnabled()) {
             LOG.debug("equal ors: " + result.toSql());
@@ -430,7 +440,8 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
         return result;
     }
 
-    private Expr makeCompoundRemaining(List<Expr> exprs, CompoundPredicate.Operator op) {
+    private Expr makeCompoundRemaining(List<Expr> exprs, CompoundPredicate.Operator op,
+            Analyzer analyzer, ExprRewriter.ClauseType clauseType) throws AnalysisException {
         if (CollectionUtils.isEmpty(exprs)) {
             return null;
         }
@@ -441,7 +452,7 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
         Expr rewritePredicate = null;
         // only OR will be rewrite to IN
         if (op == CompoundPredicate.Operator.OR) {
-            rewritePredicate = rewriteOrToIn(exprs);
+            rewritePredicate = rewriteOrToIn(exprs, analyzer, clauseType);
             // IF rewrite finished, rewritePredicate will not be null
             // IF not rewrite, do compoundPredicate
             if (rewritePredicate != null) {
@@ -457,60 +468,109 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
         return result;
     }
 
-    private Expr rewriteOrToIn(List<Expr> exprs) {
+    private Expr rewriteOrToIn(List<Expr> exprs, Analyzer analyzer, ExprRewriter.ClauseType clauseType)
+            throws AnalysisException {
         // remainingOR  expr = BP IP
         InPredicate inPredicate = null;
-        boolean isOrToInAllowed = true;
-        Set<String> slotSet = new LinkedHashSet<>();
-
         int rewriteThreshold;
         if (ConnectContext.get() == null) {
             rewriteThreshold = 2;
         } else {
             rewriteThreshold = ConnectContext.get().getSessionVariable().getRewriteOrToInPredicateThreshold();
         }
+        List<Expr> notMergedExprs = Lists.newArrayList();
+        /**
+         * col1= 1 or col1=2 or col2=3 or col2=4 or col1 != 5 or col1 not in (2)
+         * ==>
+         * slotNameToMergeExprsMap:
+         *  {
+         *      col1:[col1=1, col1=2],
+         *      col2:[col2=3, col2=4]
+         *  }
+         * notMergedExprs: [col1 != 5, col1 not in (2)]
+         */
+        Map<String, List<Expr>> slotNameToMergeExprsMap = new HashMap<>();
+        /*
+        slotNameForMerge is keys of slotNameToMergeExprsMap, but reserves slot orders in original expr.
+        To reserve orders, we can get stable output, and hence good for unit/regression test.
+         */
+        List<String> slotNameForMerge = Lists.newArrayList();
 
         for (int i = 0; i < exprs.size(); i++) {
             Expr predicate = exprs.get(i);
-            if (!(predicate instanceof BinaryPredicate) && !(predicate instanceof InPredicate)) {
-                isOrToInAllowed = false;
-                break;
+            if (predicate instanceof CompoundPredicate
+                    && ((CompoundPredicate) predicate).getOp() == Operator.AND) {
+                CompoundPredicate and = (CompoundPredicate) predicate;
+                Expr left = and.getChild(0);
+                if (left instanceof CompoundPredicate) {
+                    left = apply(and.getChild(0), analyzer, clauseType);
+                    if (CompoundPredicate.isOr(left)) {
+                        left.setPrintSqlInParens(true);
+                    }
+                }
+                Expr right = and.getChild(1);
+                if (right instanceof CompoundPredicate) {
+                    right = apply(and.getChild(1), analyzer, clauseType);
+                    if (CompoundPredicate.isOr(right)) {
+                        right.setPrintSqlInParens(true);
+                    }
+                }
+                notMergedExprs.add(new CompoundPredicate(Operator.AND, left, right));
+            } else if (!(predicate instanceof BinaryPredicate) && !(predicate instanceof InPredicate)) {
+                notMergedExprs.add(predicate);
             } else if (!(predicate.getChild(0) instanceof SlotRef)) {
-                isOrToInAllowed = false;
-                break;
+                notMergedExprs.add(predicate);
             } else if (!(predicate.getChild(1) instanceof LiteralExpr)) {
-                isOrToInAllowed = false;
-                break;
+                notMergedExprs.add(predicate);
             } else if (predicate instanceof BinaryPredicate
                     && ((BinaryPredicate) predicate).getOp() != BinaryPredicate.Operator.EQ) {
-                isOrToInAllowed = false;
-                break;
+                notMergedExprs.add(predicate);
+            } else if (predicate instanceof InPredicate
+                    && ((InPredicate) predicate).isNotIn()) {
+                notMergedExprs.add(predicate);
             } else {
                 TableName tableName = ((SlotRef) predicate.getChild(0)).getTableName();
+                String columnWithTable;
                 if (tableName != null) {
                     String tblName = tableName.toString();
-                    String columnWithTable = tblName + "." + ((SlotRef) predicate.getChild(0)).getColumnName();
-                    slotSet.add(columnWithTable);
+                    columnWithTable = tblName + "." + ((SlotRef) predicate.getChild(0)).getColumnName();
                 } else {
-                    slotSet.add(((SlotRef) predicate.getChild(0)).getColumnName());
+                    columnWithTable = ((SlotRef) predicate.getChild(0)).getColumnName();
                 }
+                slotNameToMergeExprsMap.computeIfAbsent(columnWithTable, key -> {
+                    slotNameForMerge.add(columnWithTable);
+                    return Lists.newArrayList();
+                });
+
+                slotNameToMergeExprsMap.get(columnWithTable).add(predicate);
             }
         }
-
-        // isOrToInAllowed : true, means can rewrite
-        // slotSet.size : nums of columnName in exprs, should be 1
-        if (isOrToInAllowed && slotSet.size() == 1) {
-            if (exprs.size() < rewriteThreshold) {
-                return null;
+        Expr notMerged = null;
+        if (!notMergedExprs.isEmpty()) {
+            notMerged = CompoundPredicate.createDisjunctivePredicate(notMergedExprs);
+        }
+        List<Expr> rewritten = Lists.newArrayList();
+        if (!slotNameToMergeExprsMap.isEmpty()) {
+            for (String columnNameWithTable : slotNameForMerge) {
+                List<Expr> toMerge = slotNameToMergeExprsMap.get(columnNameWithTable);
+                if (toMerge.size() < rewriteThreshold) {
+                    rewritten.addAll(toMerge);
+                } else {
+                    List<Expr> deduplicationExprs = getDeduplicationList(toMerge);
+                    inPredicate = new InPredicate(deduplicationExprs.get(0),
+                            deduplicationExprs.subList(1, deduplicationExprs.size()), false);
+                    rewritten.add(inPredicate);
+                }
             }
-
-            // get deduplication list
-            List<Expr> deduplicationExprs = getDeduplicationList(exprs);
-            inPredicate = new InPredicate(deduplicationExprs.get(0),
-                    deduplicationExprs.subList(1, deduplicationExprs.size()), false);
         }
-
-        return inPredicate;
+        if (rewritten.isEmpty()) {
+            return notMerged;
+        } else {
+            if (notMerged != null) {
+                rewritten.add(notMerged);
+            }
+            return CompoundPredicate.createDisjunctivePredicate(rewritten);
+        }
     }
 
     public List<Expr> getDeduplicationList(List<Expr> exprs) {
diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
index 7e0137e6bd..39dc947fe1 100755
--- a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
@@ -289,7 +289,6 @@ public class SelectStmtTest {
         String betweenExpanded3 = "`t1`.`k4` >= 50 AND `t1`.`k4` <= 250";
 
         String rewrittenSql = stmt.toSql();
-        System.out.println(rewrittenSql);
         Assert.assertTrue(rewrittenSql.contains(commonExpr1));
         Assert.assertEquals(rewrittenSql.indexOf(commonExpr1), rewrittenSql.lastIndexOf(commonExpr1));
         Assert.assertTrue(rewrittenSql.contains(commonExpr2));
@@ -330,13 +329,18 @@ public class SelectStmtTest {
                 + ")";
         SelectStmt stmt2 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql2, ctx);
         stmt2.rewriteExprs(new Analyzer(ctx.getEnv(), ctx).getExprRewriter());
-        String fragment3 = "((`t1`.`k1` = `t2`.`k3` AND `t2`.`k2` = 'United States'"
-                + " AND `t2`.`k3` IN ('CO', 'IL', 'MN') "
-                + "AND `t1`.`k4` >= 100 AND `t1`.`k4` <= 200) "
-                + "OR (`t1`.`k1` = `t2`.`k1` AND `t2`.`k2` = 'United States1' "
-                + "AND `t2`.`k3` IN ('OH', 'MT', 'NM') AND `t1`.`k4` >= 150 AND `t1`.`k4` <= 300) "
-                + "OR (`t1`.`k1` = `t2`.`k1` AND `t2`.`k2` = 'United States' AND `t2`.`k3` IN ('TX', 'MO', 'MI') "
-                + "AND `t1`.`k4` >= 50 AND `t1`.`k4` <= 250))";
+        String fragment3 =
+                "(((`t1`.`k4` >= 50 AND `t1`.`k4` <= 300) AND `t2`.`k2` IN ('United States', 'United States1') "
+                        + "AND `t2`.`k3` IN ('CO', 'IL', 'MN', 'OH', 'MT', 'NM', 'TX', 'MO', 'MI')) "
+                        + "AND `t1`.`k1` = `t2`.`k3` AND `t2`.`k2` = 'United States' "
+                        + "AND `t2`.`k3` IN ('CO', 'IL', 'MN') AND `t1`.`k4` >= 100 AND `t1`.`k4` <= 200 "
+                        + "OR "
+                        + "`t1`.`k1` = `t2`.`k1` AND `t2`.`k2` = 'United States1' "
+                        + "AND `t2`.`k3` IN ('OH', 'MT', 'NM') AND `t1`.`k4` >= 150 AND `t1`.`k4` <= 300 "
+                        + "OR "
+                        + "`t1`.`k1` = `t2`.`k1` AND `t2`.`k2` = 'United States' "
+                        + "AND `t2`.`k3` IN ('TX', 'MO', 'MI') "
+                        + "AND `t1`.`k4` >= 50 AND `t1`.`k4` <= 250)";
         Assert.assertTrue(stmt2.toSql().contains(fragment3));
 
         String sql3 = "select\n"
@@ -396,7 +400,7 @@ public class SelectStmtTest {
         SelectStmt stmt7 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql7, ctx);
         stmt7.rewriteExprs(new Analyzer(ctx.getEnv(), ctx).getExprRewriter());
         Assert.assertTrue(stmt7.toSql()
-                .contains("`t2`.`k1` IS NOT NULL OR (`t1`.`k1` IS NOT NULL " + "AND `t1`.`k2` IS NOT NULL)"));
+                .contains("`t2`.`k1` IS NOT NULL OR `t1`.`k1` IS NOT NULL AND `t1`.`k2` IS NOT NULL"));
 
         String sql8 = "select\n"
                 + "   avg(t1.k4)\n"
@@ -408,13 +412,13 @@ public class SelectStmtTest {
         SelectStmt stmt8 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql8, ctx);
         stmt8.rewriteExprs(new Analyzer(ctx.getEnv(), ctx).getExprRewriter());
         Assert.assertTrue(stmt8.toSql()
-                .contains("`t2`.`k1` IS NOT NULL AND `t1`.`k1` IS NOT NULL" + " AND `t1`.`k1` IS NOT NULL"));
+                .contains("`t2`.`k1` IS NOT NULL AND `t1`.`k1` IS NOT NULL AND `t1`.`k1` IS NOT NULL"));
 
         String sql9 = "select * from db1.tbl1 where (k1='shutdown' and k4<1) or (k1='switchOff' and k4>=1)";
         SelectStmt stmt9 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql9, ctx);
         stmt9.rewriteExprs(new Analyzer(ctx.getEnv(), ctx).getExprRewriter());
         Assert.assertTrue(
-                stmt9.toSql().contains("(`k1` = 'shutdown' AND `k4` < 1)" + " OR (`k1` = 'switchOff' AND `k4` >= 1)"));
+                stmt9.toSql().contains("`k1` = 'shutdown' AND `k4` < 1 OR `k1` = 'switchOff' AND `k4` >= 1"));
     }
 
     @Test
diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
index 9d8e560c11..5b403f95b4 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
@@ -2220,40 +2220,81 @@ public class QueryPlanTest extends TestWithFeService {
         Assert.assertTrue(explainString.contains("PREAGGREGATION: ON"));
     }
 
+    /*
+    NOTE:
+    explainString.contains("PREDICATES: xxx\n")
+    add '\n' at the end of line to ensure there are no other predicates
+     */
     @Test
     public void testRewriteOrToIn() throws Exception {
         connectContext.setDatabase("default_cluster:test");
         String sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 4)";
         String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
-        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3, 4)"));
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3, 4)\n"));
 
         sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and query_time in (3, 4)";
         explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
-        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `query_time` IN (3, 4)"));
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `query_time` IN (3, 4)\n"));
 
         sql = "SELECT * from test1 where (query_time = 1 or query_time = 2 or scan_bytes = 2) and scan_bytes in (2, 3)";
         explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
-        Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 1 OR `query_time` = 2 OR `scan_bytes` = 2), `scan_bytes` IN (2, 3)"));
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2) OR `scan_bytes` = 2, `scan_bytes` IN (2, 3)\n"));
 
         sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and (scan_bytes = 2 or scan_bytes = 3)";
         explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
-        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (2, 3)"));
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (2, 3)\n"));
 
         sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time = 3 or query_time = 1";
         explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
-        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3)"));
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3)\n"));
 
         sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 2)";
         explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
-        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3)"));
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3)\n"));
 
         connectContext.getSessionVariable().setRewriteOrToInPredicateThreshold(100);
         sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 4)";
         explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
-        Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 1 OR `query_time` = 2 OR `query_time` IN (3, 4))"));
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` = 1 OR `query_time` = 2 OR `query_time` IN (3, 4)\n"));
+        connectContext.getSessionVariable().setRewriteOrToInPredicateThreshold(2);
 
         sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and query_time in (3, 4)";
         explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
-        Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 1 OR `query_time` = 2), `query_time` IN (3, 4)"));
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `query_time` IN (3, 4)\n"));
+
+        //test we can handle `!=` and `not in`
+        sql = "select * from test1 where (query_time = 1 or query_time = 2 or query_time!= 3 or query_time not in (5, 6))";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2) OR `query_time` != 3 OR `query_time` NOT IN (5, 6)\n"));
+
+        //test we can handle merge 2 or more columns
+        sql = "select * from test1 where (query_time = 1 or query_time = 2 or scan_rows = 3 or scan_rows = 4)";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2) OR `scan_rows` IN (3, 4)"));
+
+        //merge in-pred or in-pred
+        sql = "select * from test1 where (query_time = 1 or query_time = 2 or query_time = 3 or query_time = 4)";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3, 4)\n"));
+
+        //rewrite recursively
+        sql = "select * from test1 "
+                + "where query_id=client_ip "
+                + "      and (stmt_id=1 or stmt_id=2 or stmt_id=3 "
+                + "           or (user='abc' and (state = 'a' or state='b' or state in ('c', 'd'))))"
+                + "      or (db not in ('x', 'y')) ";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains(
+                "PREDICATES: `query_id` = `client_ip` "
+                        + "AND (`stmt_id` IN (1, 2, 3) OR `user` = 'abc' AND `state` IN ('a', 'b', 'c', 'd')) "
+                        + "OR (`db` NOT IN ('x', 'y'))\n"));
+
+        //ExtractCommonFactorsRule may generate more expr, test the rewriteOrToIn applied on generated exprs
+        sql = "select * from test1 where (stmt_id=1 and state='a') or (stmt_id=2 and state='b')";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains(
+                "PREDICATES: `state` IN ('a', 'b'), `stmt_id` IN (1, 2),"
+                        + " `stmt_id` = 1 AND `state` = 'a' OR `stmt_id` = 2 AND `state` = 'b'\n"
+        ));
     }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java
index 2f994be85b..5458f30043 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java
@@ -100,15 +100,15 @@ public class ExtractCommonFactorsRuleFunctionTest {
     public void testWideCommonFactorsWithOrPredicate() throws Exception {
         String query = "select * from tb1 where tb1.k1 > 1000 or tb1.k1 < 200 or tb1.k1 = 300";
         String planString = dorisAssert.query(query).explainQuery();
-        Assert.assertTrue(planString.contains("PREDICATES: (`tb1`.`k1` > 1000 OR `tb1`.`k1` < 200 OR `tb1`.`k1` = 300)"));
+        Assert.assertTrue(planString.contains("PREDICATES: `tb1`.`k1` = 300 OR `tb1`.`k1` > 1000 OR `tb1`.`k1` < 200"));
     }
 
     @Test
     public void testWideCommonFactorsWithEqualPredicate() throws Exception {
         String query = "select * from tb1, tb2 where (tb1.k1=1 and tb2.k1=1) or (tb1.k1 =2 and tb2.k1=2)";
         String planString = dorisAssert.query(query).explainQuery();
-        Assert.assertTrue(planString.contains("(`tb1`.`k1` = 1 OR `tb1`.`k1` = 2)"));
-        Assert.assertTrue(planString.contains("(`tb2`.`k1` = 1 OR `tb2`.`k1` = 2)"));
+        Assert.assertTrue(planString.contains("`tb1`.`k1` IN (1, 2)"));
+        Assert.assertTrue(planString.contains("`tb2`.`k1` IN (1, 2)"));
         Assert.assertTrue(planString.contains("NESTED LOOP JOIN"));
     }
 
@@ -259,10 +259,10 @@ public class ExtractCommonFactorsRuleFunctionTest {
         Assert.assertTrue(planString.contains("`l_partkey` = `p_partkey`"));
         Assert.assertTrue(planString.contains("`l_shipmode` IN ('AIR', 'AIR REG')"));
         Assert.assertTrue(planString.contains("`l_shipinstruct` = 'DELIVER IN PERSON'"));
-        Assert.assertTrue(planString.contains("((`l_quantity` >= 9 AND `l_quantity` <= 19) "
-                + "OR (`l_quantity` >= 20 AND `l_quantity` <= 36))"));
+        Assert.assertTrue(planString.contains("`l_quantity` >= 9 AND `l_quantity` <= 19 "
+                + "OR `l_quantity` >= 20 AND `l_quantity` <= 36"));
         Assert.assertTrue(planString.contains("`p_size` >= 1"));
-        Assert.assertTrue(planString.contains("(`p_brand` = 'Brand#11' OR `p_brand` = 'Brand#21' OR `p_brand` = 'Brand#32')"));
+        Assert.assertTrue(planString.contains("`p_brand` IN ('Brand#11', 'Brand#21', 'Brand#32')"));
         Assert.assertTrue(planString.contains("`p_size` <= 15"));
         Assert.assertTrue(planString.contains("`p_container` IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG', 'MED BAG', "
                 + "'MED BOX', 'MED PKG', 'MED PACK', 'LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')"));
diff --git a/regression-test/data/performance_p0/redundant_conjuncts.out b/regression-test/data/performance_p0/redundant_conjuncts.out
index 98178f31aa..3baa5b3d93 100644
--- a/regression-test/data/performance_p0/redundant_conjuncts.out
+++ b/regression-test/data/performance_p0/redundant_conjuncts.out
@@ -23,7 +23,7 @@ PLAN FRAGMENT 0
 
   0:VOlapScanNode
      TABLE: default_cluster:regression_test_performance_p0.redundant_conjuncts(redundant_conjuncts), PREAGGREGATION: OFF. Reason: No AggregateInfo
-     PREDICATES: (`k1` = 1 OR `k1` = 2)
+     PREDICATES: `k1` = 1 OR `k1` = 2
      partitions=0/1, tablets=0/0, tabletList=
      cardinality=0, avgRowSize=8.0, numNodes=1
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org