You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2020/04/09 13:57:53 UTC
[incubator-doris] branch master updated: [Query] Optimize where
clause by extracting the common predicate in the OR compound predicate.
(#3278)
This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new 8699bb7 [Query] Optimize where clause by extracting the common predicate in the OR compound predicate. (#3278)
8699bb7 is described below
commit 8699bb7bd47033c6dd5db3e9021ef3cc5d324f63
Author: yangzhg <78...@qq.com>
AuthorDate: Thu Apr 9 21:57:45 2020 +0800
[Query] Optimize where clause by extracting the common predicate in the OR compound predicate. (#3278)
Queries like the one below cannot finish in an acceptable time. `store_sales` has 28 million (2800w) rows and `customer_address` has 50,000 (5w) rows, and for now Doris creates only a single cross join node to execute this SQL.
Evaluating the WHERE clause takes about 200-300 ns per row pair, and it must be evaluated 2800w * 5w = 1.4 trillion times, which would cost about 2800w * 5w * 250 ns ≈ 350,000 seconds (roughly 4 days):
```
select avg(ss_quantity)
,avg(ss_ext_sales_price)
,avg(ss_ext_wholesale_cost)
,sum(ss_ext_wholesale_cost)
from store_sales, customer_address
where ((ss_addr_sk = ca_address_sk
and ca_country = 'United States'
and ca_state in ('CO', 'IL', 'MN')
and ss_net_profit between 100 and 200
) or
(ss_addr_sk = ca_address_sk
and ca_country = 'United States'
and ca_state in ('OH', 'MT', 'NM')
and ss_net_profit between 150 and 300
) or
(ss_addr_sk = ca_address_sk
and ca_country = 'United States'
and ca_state in ('TX', 'MO', 'MI')
and ss_net_profit between 50 and 250
))
```
But this SQL can be rewritten as:
```
select avg(ss_quantity)
,avg(ss_ext_sales_price)
,avg(ss_ext_wholesale_cost)
,sum(ss_ext_wholesale_cost)
from store_sales, customer_address
where ss_addr_sk = ca_address_sk
and ca_country = 'United States' and (((ca_state in ('CO', 'IL', 'MN')
and ss_net_profit between 100 and 200
) or
(ca_state in ('OH', 'MT', 'NM')
and ss_net_profit between 150 and 300
) or
(ca_state in ('TX', 'MO', 'MI')
and ss_net_profit between 50 and 250
))
)
```
Therefore we can do a hash join first and then use
```
(((ca_state in ('CO', 'IL', 'MN')
and ss_net_profit between 100 and 200
) or
(ca_state in ('OH', 'MT', 'NM')
and ss_net_profit between 150 and 300
) or
(ca_state in ('TX', 'MO', 'MI')
and ss_net_profit between 50 and 250
))
)
```
to filter the values.
On the TPC-DS 10 GB dataset, the rewritten SQL takes only about 1 second.
---
.../java/org/apache/doris/analysis/SelectStmt.java | 131 ++++++++++++++
.../org/apache/doris/analysis/SelectStmtTest.java | 198 ++++++++++++++++++++-
2 files changed, 324 insertions(+), 5 deletions(-)
diff --git a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
index ebb5c50..2e84409 100644
--- a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
+++ b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
@@ -54,6 +54,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -495,6 +496,10 @@ public class SelectStmt extends QueryStmt {
}
private void whereClauseRewrite() {
+ Expr deDuplicatedWhere = deduplicateOrs(whereClause);
+ if (deDuplicatedWhere != null) {
+ whereClause = deDuplicatedWhere;
+ }
if (whereClause instanceof IntLiteral) {
if (((IntLiteral) whereClause).getLongValue() == 0) {
whereClause = new BoolLiteral(false);
@@ -505,6 +510,132 @@ public class SelectStmt extends QueryStmt {
}
/**
+     * Flattens an OR tree into its disjuncts, splitting each disjunct into its
+     * AND-connected conjuncts. For example {@code (a and b and c) or (d and e and f)}
+     * becomes {@code [[a, b, c], [d, e, f]]}.
+     *
+     * <p>Nested ORs directly under an OR are flattened recursively, AND children are
+     * split with {@code flatAndExpr}, and any other child (including a NOT predicate)
+     * becomes a single-element conjunct list.
+     */
+    private List<List<Expr>> extractDuplicateOrs(CompoundPredicate expr) {
+        List<List<Expr>> disjuncts = new ArrayList<>();
+        for (Expr child : expr.getChildren()) {
+            // Only compound predicates have an operator worth dispatching on.
+            CompoundPredicate.Operator childOp = child instanceof CompoundPredicate
+                    ? ((CompoundPredicate) child).getOp() : null;
+            if (childOp == CompoundPredicate.Operator.OR) {
+                disjuncts.addAll(extractDuplicateOrs((CompoundPredicate) child));
+            } else if (childOp == CompoundPredicate.Operator.AND) {
+                disjuncts.add(flatAndExpr(child));
+            } else {
+                disjuncts.add(Arrays.asList(child));
+            }
+        }
+        return disjuncts;
+    }
+
+    /**
+     * Applies the inverse OR distributive law to an expression tree:
+     * {@code ((A AND B) OR (A AND C)) => (A AND (B OR C))}.
+     * That is, it locates OR clauses in which every disjunct contains an identical
+     * conjunct and pulls the duplicated conjuncts out in front of the OR.
+     *
+     * <p>If {@code expr} is itself an OR, the whole (flattened) OR is rewritten and the
+     * rewritten expression is returned. Otherwise each child is rewritten recursively and
+     * replaced in place via {@code setChild}, and {@code expr} itself is returned.
+     *
+     * <p>NOTE(review): ORs nested inside the conjuncts of an OR that gets rewritten
+     * (e.g. {@code (x AND (p OR q)) OR y}) are treated as opaque leaves and are not
+     * revisited by this pass.
+     *
+     * @param expr the expression to rewrite; may be {@code null}
+     * @return the rewritten expression, or {@code null} when {@code expr} is {@code null}
+     */
+    private Expr deduplicateOrs(Expr expr) {
+        if (expr == null) {
+            // whereClauseRewrite passes whereClause in without a null check; there is
+            // nothing to rewrite, and the caller keeps its current value on null.
+            return null;
+        }
+        if (expr instanceof CompoundPredicate
+                && ((CompoundPredicate) expr).getOp() == CompoundPredicate.Operator.OR) {
+            Expr rewrittenExpr = processDuplicateOrs(extractDuplicateOrs((CompoundPredicate) expr));
+            if (rewrittenExpr != null) {
+                return rewrittenExpr;
+            }
+        } else {
+            for (int i = 0; i < expr.getChildren().size(); i++) {
+                Expr rewrittenChild = deduplicateOrs(expr.getChild(i));
+                if (rewrittenChild != null) {
+                    expr.setChild(i, rewrittenChild);
+                }
+            }
+        }
+        return expr;
+    }
+
+    /**
+     * Flattens a tree of AND predicates into its conjuncts in left-to-right order,
+     * e.g. {@code a and b and c => [a, b, c]}. An expression that is not an AND
+     * predicate is returned as a single-element list containing itself.
+     */
+    private List<Expr> flatAndExpr(Expr expr) {
+        List<Expr> conjuncts = new ArrayList<>();
+        // Explicit stack instead of recursion: the right child is pushed before the left
+        // one, so the left subtree is popped (and emitted) first, keeping source order.
+        List<Expr> pending = new ArrayList<>();
+        pending.add(expr);
+        while (!pending.isEmpty()) {
+            Expr current = pending.remove(pending.size() - 1);
+            boolean isAnd = current instanceof CompoundPredicate
+                    && ((CompoundPredicate) current).getOp() == CompoundPredicate.Operator.AND;
+            if (isAnd) {
+                pending.add(current.getChild(1));
+                pending.add(current.getChild(0));
+            } else {
+                conjuncts.add(current);
+            }
+        }
+        return conjuncts;
+    }
+
+    /**
+     * Rewrites a disjunction of conjunctions by factoring out the conjuncts common to
+     * every disjunct.
+     *
+     * <p>The input is the output of {@code extractDuplicateOrs}: the outer list is
+     * OR-connected and each inner list is AND-connected. For example
+     * {@code (a and b and c) or (a and e and f)} arrives as {@code [[a, b, c], [a, e, f]]};
+     * the common conjunct {@code a} is separated out and the result is rebuilt as
+     * {@code a and ((b and c) or (e and f))}.
+     *
+     * <p>When factoring empties a disjunct, as in {@code a or (a and b)}, the absorption
+     * law {@code A OR (A AND B) = A} applies and only the common factor is returned.
+     *
+     * @param exprs OR-connected list of AND-connected conjunct lists
+     * @return the rewritten predicate, or {@code null} if there are fewer than two disjuncts
+     */
+    private Expr processDuplicateOrs(List<List<Expr>> exprs) {
+        if (exprs.size() < 2) {
+            return null;
+        }
+        // 1. Remove duplicate conjuncts inside each disjunct, then duplicate disjuncts
+        //    (compared as sets, so conjunct order does not matter):
+        //    [[a, a], [a, b], [b, a]] => [[a], [a, b]]
+        Set<Set<Expr>> distinctDisjuncts = new LinkedHashSet<>();
+        for (List<Expr> disjunct : exprs) {
+            distinctDisjuncts.add(new LinkedHashSet<>(disjunct));
+        }
+        List<List<Expr>> clearExprs = new ArrayList<>();
+        for (Set<Expr> disjunct : distinctDisjuncts) {
+            clearExprs.add(new ArrayList<>(disjunct));
+        }
+        if (clearExprs.size() == 1) {
+            // A OR A => A
+            return makeCompound(clearExprs.get(0), CompoundPredicate.Operator.AND);
+        }
+
+        // 2. Find the conjuncts that appear in every disjunct.
+        List<Expr> commonExprs = new ArrayList<>(clearExprs.get(0));
+        for (int i = 1; i < clearExprs.size(); ++i) {
+            commonExprs.retainAll(clearExprs.get(i));
+        }
+
+        // 3. Strip the common conjuncts from every disjunct. A disjunct that becomes empty,
+        //    e.g. (a) OR (a AND b), means the whole OR equals the common factor
+        //    (absorption); without this check the empty disjunct would become a null
+        //    operand of the rebuilt OR.
+        List<Expr> residualExprs = new ArrayList<>();
+        boolean absorbed = false;
+        for (List<Expr> exprList : clearExprs) {
+            exprList.removeAll(commonExprs);
+            if (exprList.isEmpty()) {
+                absorbed = true;
+            } else {
+                residualExprs.add(makeCompound(exprList, CompoundPredicate.Operator.AND));
+            }
+        }
+
+        // 4. Rebuild: (common) AND ((residual) OR (residual) ...), or just the OR of the
+        //    residuals when no conjunct is shared by every disjunct.
+        Expr result;
+        if (commonExprs.isEmpty()) {
+            result = makeCompound(residualExprs, CompoundPredicate.Operator.OR);
+        } else if (absorbed) {
+            result = makeCompound(commonExprs, CompoundPredicate.Operator.AND);
+        } else {
+            result = new CompoundPredicate(CompoundPredicate.Operator.AND,
+                    makeCompound(commonExprs, CompoundPredicate.Operator.AND),
+                    makeCompound(residualExprs, CompoundPredicate.Operator.OR));
+        }
+        if (LOG.isDebugEnabled()) {
+            // Avoid rendering the predicate to SQL when debug logging is off.
+            LOG.debug("rewrite ors: " + result.toSql());
+        }
+        return result;
+    }
+
+    /**
+     * Rebuilds a left-deep compound predicate from a list of operands joined by the given
+     * operator, e.g. {@code [a, e, f]} with AND => {@code (a and e) and f}.
+     *
+     * @param exprs operands to combine
+     * @param op    the connective (AND / OR) used to join the operands
+     * @return {@code null} for a null or empty list, the sole element for a single-element
+     *         list, otherwise the combined predicate
+     */
+    private Expr makeCompound(List<Expr> exprs, CompoundPredicate.Operator op) {
+        if (CollectionUtils.isEmpty(exprs)) {
+            return null;
+        }
+        Expr result = exprs.get(0);
+        for (int i = 1; i < exprs.size(); ++i) {
+            // The partially built tree is referenced only by the new node, so it is not
+            // cloned; cloning it on every step deep-copied the tree quadratically.
+            result = new CompoundPredicate(op, result, exprs.get(i));
+        }
+        return result;
+    }
+
+ /**
* Generates and registers !empty() predicates to filter out empty collections directly
* in the parent scan of collection table refs. This is a performance optimization to
* avoid the expensive processing of empty collections inside a subplan that would
diff --git a/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java b/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
index 171776b..92db5b1 100644
--- a/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
+++ b/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
@@ -20,6 +20,7 @@ package org.apache.doris.analysis;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rewrite.ExprRewriter;
+import org.apache.doris.thrift.TPrimitiveType;
import org.apache.doris.utframe.DorisAssert;
import org.apache.doris.utframe.UtFrameUtils;
import org.junit.AfterClass;
@@ -29,6 +30,7 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
+import java.io.IOException;
import java.util.UUID;
public class SelectStmtTest {
@@ -89,10 +91,196 @@ public class SelectStmtTest {
"FROM db1.tbl1;";
SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql1, ctx);
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
- Assert.assertEquals("SELECT CASE WHEN `$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3` END" +
- " AS `kk4` FROM `default_cluster:db1`.`tbl1` (SELECT count(*) / 2.0 AS `count(*) / 2.0` FROM " +
- "`default_cluster:db1`.`tbl1`) $a$1 (SELECT avg(`k4`) AS `avg(``k4``)` FROM" +
- " `default_cluster:db1`.`tbl1`) $a$2 (SELECT sum(`k4`) AS `sum(``k4``)` " +
- "FROM `default_cluster:db1`.`tbl1`) $a$3", stmt.toSql());
+ Assert.assertTrue(stmt.toSql().contains("`$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3`"));
+ }
+
+ /**
+ * Checks the WHERE-clause OR deduplication rewrite. Each case parses and analyzes a query,
+ * runs rewriteExprs, and inspects the exact toSql() output, so the expected fragments are
+ * sensitive to how the rewritten predicate tree is parenthesized.
+ */
+ @Test
+ public void testDeduplicateOrs() throws Exception {
+ ConnectContext ctx = UtFrameUtils.createDefaultCtx();
+ // Case 1: two OR groups whose disjuncts all repeat the same join conjuncts
+ // (t1.k2 = t4.k2 / t3.k3 = t1.k3 and t1.k1 = t5.k1 / t5.k2 = 'United States');
+ // those common conjuncts are expected to be pulled in front of each OR.
+ String sql = "select\n" +
+ " avg(t1.k4)\n" +
+ "from\n" +
+ " db1.tbl1 t1,\n" +
+ " db1.tbl1 t2,\n" +
+ " db1.tbl1 t3,\n" +
+ " db1.tbl1 t4,\n" +
+ " db1.tbl1 t5,\n" +
+ " db1.tbl1 t6\n" +
+ "where\n" +
+ " t2.k1 = t1.k1\n" +
+ " and t1.k2 = t6.k2\n" +
+ " and t6.k4 = 2001\n" +
+ " and(\n" +
+ " (\n" +
+ " t1.k2 = t4.k2\n" +
+ " and t3.k3 = t1.k3\n" +
+ " and t3.k1 = 'D'\n" +
+ " and t4.k3 = '2 yr Degree'\n" +
+ " and t1.k4 between 100.00\n" +
+ " and 150.00\n" +
+ " and t4.k4 = 3\n" +
+ " )\n" +
+ " or (\n" +
+ " t1.k2 = t4.k2\n" +
+ " and t3.k3 = t1.k3\n" +
+ " and t3.k1 = 'S'\n" +
+ " and t4.k3 = 'Secondary'\n" +
+ " and t1.k4 between 50.00\n" +
+ " and 100.00\n" +
+ " and t4.k4 = 1\n" +
+ " )\n" +
+ " or (\n" +
+ " t1.k2 = t4.k2\n" +
+ " and t3.k3 = t1.k3\n" +
+ " and t3.k1 = 'W'\n" +
+ " and t4.k3 = 'Advanced Degree'\n" +
+ " and t1.k4 between 150.00\n" +
+ " and 200.00\n" +
+ " and t4.k4 = 1\n" +
+ " )\n" +
+ " )\n" +
+ " and(\n" +
+ " (\n" +
+ " t1.k1 = t5.k1\n" +
+ " and t5.k2 = 'United States'\n" +
+ " and t5.k3 in ('CO', 'IL', 'MN')\n" +
+ " and t1.k4 between 100\n" +
+ " and 200\n" +
+ " )\n" +
+ " or (\n" +
+ " t1.k1 = t5.k1\n" +
+ " and t5.k2 = 'United States'\n" +
+ " and t5.k3 in ('OH', 'MT', 'NM')\n" +
+ " and t1.k4 between 150\n" +
+ " and 300\n" +
+ " )\n" +
+ " or (\n" +
+ " t1.k1 = t5.k1\n" +
+ " and t5.k2 = 'United States'\n" +
+ " and t5.k3 in ('TX', 'MO', 'MI')\n" +
+ " and t1.k4 between 50 and 250\n" +
+ " )\n" +
+ " );";
+ SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql, ctx);
+ stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+ String rewritedFragment1 = "(((`t1`.`k2` = `t4`.`k2`) AND (`t3`.`k3` = `t1`.`k3`)) AND ((((((`t3`.`k1` = 'D')" +
+ " AND (`t4`.`k3` = '2 yr Degree')) AND ((`t1`.`k4` >= 100.00) AND (`t1`.`k4` <= 150.00))) AND" +
+ " (`t4`.`k4` = 3)) OR ((((`t3`.`k1` = 'S') AND (`t4`.`k3` = 'Secondary')) AND ((`t1`.`k4` >= 50.00)" +
+ " AND (`t1`.`k4` <= 100.00))) AND (`t4`.`k4` = 1))) OR ((((`t3`.`k1` = 'W') AND " +
+ "(`t4`.`k3` = 'Advanced Degree')) AND ((`t1`.`k4` >= 150.00) AND (`t1`.`k4` <= 200.00)))" +
+ " AND (`t4`.`k4` = 1))))";
+ String rewritedFragment2 = "(((`t1`.`k1` = `t5`.`k1`) AND (`t5`.`k2` = 'United States')) AND" +
+ " ((((`t5`.`k3` IN ('CO', 'IL', 'MN')) AND ((`t1`.`k4` >= 100) AND (`t1`.`k4` <= 200)))" +
+ " OR ((`t5`.`k3` IN ('OH', 'MT', 'NM')) AND ((`t1`.`k4` >= 150) AND (`t1`.`k4` <= 300))))" +
+ " OR ((`t5`.`k3` IN ('TX', 'MO', 'MI')) AND ((`t1`.`k4` >= 50) AND (`t1`.`k4` <= 250)))))";
+ Assert.assertTrue(stmt.toSql().contains(rewritedFragment1));
+ Assert.assertTrue(stmt.toSql().contains(rewritedFragment2));
+
+ // Case 2: no conjunct appears in all three disjuncts (the first joins on t2.k3, the
+ // second uses 'United States1'), so the OR is expected to keep its original shape.
+ String sql2 = "select\n" +
+ " avg(t1.k4)\n" +
+ "from\n" +
+ " db1.tbl1 t1,\n" +
+ " db1.tbl1 t2\n" +
+ "where\n" +
+ "(\n" +
+ " t1.k1 = t2.k3\n" +
+ " and t2.k2 = 'United States'\n" +
+ " and t2.k3 in ('CO', 'IL', 'MN')\n" +
+ " and t1.k4 between 100\n" +
+ " and 200\n" +
+ ")\n" +
+ "or (\n" +
+ " t1.k1 = t2.k1\n" +
+ " and t2.k2 = 'United States1'\n" +
+ " and t2.k3 in ('OH', 'MT', 'NM')\n" +
+ " and t1.k4 between 150\n" +
+ " and 300\n" +
+ ")\n" +
+ "or (\n" +
+ " t1.k1 = t2.k1\n" +
+ " and t2.k2 = 'United States'\n" +
+ " and t2.k3 in ('TX', 'MO', 'MI')\n" +
+ " and t1.k4 between 50 and 250\n" +
+ ")";
+ SelectStmt stmt2 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql2, ctx);
+ stmt2.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+ String fragment3 = "(((((`t1`.`k1` = `t2`.`k3`) AND (`t2`.`k2` = 'United States')) AND " +
+ "(`t2`.`k3` IN ('CO', 'IL', 'MN'))) AND ((`t1`.`k4` >= 100) AND (`t1`.`k4` <= 200))) OR" +
+ " ((((`t1`.`k1` = `t2`.`k1`) AND (`t2`.`k2` = 'United States1')) AND (`t2`.`k3` IN ('OH', 'MT', 'NM')))" +
+ " AND ((`t1`.`k4` >= 150) AND (`t1`.`k4` <= 300)))) OR ((((`t1`.`k1` = `t2`.`k1`) AND " +
+ "(`t2`.`k2` = 'United States')) AND (`t2`.`k3` IN ('TX', 'MO', 'MI'))) AND ((`t1`.`k4` >= 50)" +
+ " AND (`t1`.`k4` <= 250)))";
+ Assert.assertTrue(stmt2.toSql().contains(fragment3));
+
+ // Case 3: three identical disjuncts collapse into one, so the chained OR must be gone.
+ String sql3 = "select\n" +
+ " avg(t1.k4)\n" +
+ "from\n" +
+ " db1.tbl1 t1,\n" +
+ " db1.tbl1 t2\n" +
+ "where\n" +
+ " t1.k1 = t2.k3 or t1.k1 = t2.k3 or t1.k1 = t2.k3";
+ SelectStmt stmt3 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql3, ctx);
+ stmt3.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+ Assert.assertFalse(stmt3.toSql().contains("((`t1`.`k1` = `t2`.`k3`) OR (`t1`.`k1` = `t2`.`k3`)) OR" +
+ " (`t1`.`k1` = `t2`.`k3`)"));
+
+ // Case 4: a repeated disjunct is dropped, leaving a two-way OR.
+ String sql4 = "select\n" +
+ " avg(t1.k4)\n" +
+ "from\n" +
+ " db1.tbl1 t1,\n" +
+ " db1.tbl1 t2\n" +
+ "where\n" +
+ " t1.k1 = t2.k2 or t1.k1 = t2.k3 or t1.k1 = t2.k3";
+ SelectStmt stmt4 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql4, ctx);
+ stmt4.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+ Assert.assertTrue(stmt4.toSql().contains("(`t1`.`k1` = `t2`.`k2`) OR (`t1`.`k1` = `t2`.`k3`)"));
+
+ // Case 5: same as case 4 with IS NOT NULL; splitting on " OR " into 2 parts means exactly
+ // one OR remains in the whole statement.
+ String sql5 = "select\n" +
+ " avg(t1.k4)\n" +
+ "from\n" +
+ " db1.tbl1 t1,\n" +
+ " db1.tbl1 t2\n" +
+ "where\n" +
+ " t2.k1 is not null or t1.k1 is not null or t1.k1 is not null";
+ SelectStmt stmt5 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql5, ctx);
+ stmt5.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+ Assert.assertTrue(stmt5.toSql().contains("(`t2`.`k1` IS NOT NULL) OR (`t1`.`k1` IS NOT NULL)"));
+ Assert.assertEquals(2, stmt5.toSql().split(" OR ").length);
+
+ // Case 6: AND binds tighter than OR, so the second disjunct is
+ // (t1.k1 is not null and t1.k1 is not null); its duplicate conjunct is removed.
+ String sql6 = "select\n" +
+ " avg(t1.k4)\n" +
+ "from\n" +
+ " db1.tbl1 t1,\n" +
+ " db1.tbl1 t2\n" +
+ "where\n" +
+ " t2.k1 is not null or t1.k1 is not null and t1.k1 is not null";
+ SelectStmt stmt6 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql6, ctx);
+ stmt6.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+ Assert.assertTrue(stmt6.toSql().contains("(`t2`.`k1` IS NOT NULL) OR (`t1`.`k1` IS NOT NULL)"));
+ Assert.assertEquals(2, stmt6.toSql().split(" OR ").length);
+
+ // Case 7: nothing to deduplicate; the AND disjunct keeps both of its conjuncts.
+ String sql7 = "select\n" +
+ " avg(t1.k4)\n" +
+ "from\n" +
+ " db1.tbl1 t1,\n" +
+ " db1.tbl1 t2\n" +
+ "where\n" +
+ " t2.k1 is not null or t1.k1 is not null and t1.k2 is not null";
+ SelectStmt stmt7 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql7, ctx);
+ stmt7.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+ Assert.assertTrue(stmt7.toSql().contains("(`t2`.`k1` IS NOT NULL) OR ((`t1`.`k1` IS NOT NULL) " +
+ "AND (`t1`.`k2` IS NOT NULL))"));
+
+ // Case 8: a pure AND chain is not an OR, so the rewrite does not touch it and the
+ // duplicate conjunct is expected to remain in the output.
+ String sql8 = "select\n" +
+ " avg(t1.k4)\n" +
+ "from\n" +
+ " db1.tbl1 t1,\n" +
+ " db1.tbl1 t2\n" +
+ "where\n" +
+ " t2.k1 is not null and t1.k1 is not null and t1.k1 is not null";
+ SelectStmt stmt8 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql8, ctx);
+ stmt8.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+ Assert.assertTrue(stmt8.toSql().contains("((`t2`.`k1` IS NOT NULL) AND (`t1`.`k1` IS NOT NULL))" +
+ " AND (`t1`.`k1` IS NOT NULL)"));
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org