You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/07 09:09:41 UTC
[arrow-datafusion] branch master updated: bugfix: remove cnf_rewrite in push_down_filter (#4825)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 2db3d2eae bugfix: remove cnf_rewrite in push_down_filter (#4825)
2db3d2eae is described below
commit 2db3d2eae9f4e51f01cb705677c8c231b43a6172
Author: jakevin <ja...@gmail.com>
AuthorDate: Sat Jan 7 17:09:35 2023 +0800
bugfix: remove cnf_rewrite in push_down_filter (#4825)
* bugfix: remove cnf_rewrite in push_down_filter
* remove cnf_rewrite() and related tests.
---
benchmarks/expected-plans/q19.txt | 13 +-
benchmarks/expected-plans/q7.txt | 2 +-
datafusion/core/tests/sql/joins.rs | 2 +-
datafusion/core/tests/sql/predicates.rs | 6 +-
datafusion/optimizer/src/optimizer.rs | 2 +-
datafusion/optimizer/src/push_down_filter.rs | 66 +++++---
datafusion/optimizer/src/utils.rs | 233 +--------------------------
7 files changed, 60 insertions(+), 264 deletions(-)
diff --git a/benchmarks/expected-plans/q19.txt b/benchmarks/expected-plans/q19.txt
index 3efc3718d..969ad02d4 100644
--- a/benchmarks/expected-plans/q19.txt
+++ b/benchmarks/expected-plans/q19.txt
@@ -1,9 +1,8 @@
Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
- Projection: lineitem.l_extendedprice, lineitem.l_discount
- Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Dec [...]
- Inner Join: lineitem.l_partkey = part.p_partkey
- Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
- TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode]
- Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AN [...]
- TableScan: part projection=[p_partkey, p_brand, p_size, p_container]
+ Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decim [...]
+ Inner Join: lineitem.l_partkey = part.p_partkey
+ Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
+ TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode]
+ Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND [...]
+ TableScan: part projection=[p_partkey, p_brand, p_size, p_container]
diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt
index 53deda1b8..bd8c10f8c 100644
--- a/benchmarks/expected-plans/q7.txt
+++ b/benchmarks/expected-plans/q7.txt
@@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST,
Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]]
SubqueryAlias: shipping
Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume
- Filter: (n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE")) AND (n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY"))
+ Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE")
Inner Join: customer.c_nationkey = n2.n_nationkey
Inner Join: supplier.s_nationkey = n1.n_nationkey
Inner Join: orders.o_custkey = customer.c_custkey
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 4cc9628d1..1de20c29c 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1561,7 +1561,7 @@ async fn reduce_left_join_2() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs
index 61d509a2d..1e8888ce4 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -590,10 +590,8 @@ async fn multiple_or_predicates() -> Result<()> {
" Projection: lineitem.l_partkey [l_partkey:Int64]",
" Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decim [...]
" Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
- " Projection: lineitem.l_partkey, lineitem.l_quantity [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
- " Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Som [...]
- " Projection: lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR [...]
- " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= De [...]
+ " Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
+ " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 3e7df55d5..7390ac204 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -215,12 +215,12 @@ impl Optimizer {
// run it again after running the optimizations that potentially converted
// subqueries to joins
Arc::new(SimplifyExpressions::new()),
+ Arc::new(RewriteDisjunctivePredicate::new()),
Arc::new(EliminateFilter::new()),
Arc::new(EliminateCrossJoin::new()),
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateLimit::new()),
Arc::new(PropagateEmptyRelation::new()),
- Arc::new(RewriteDisjunctivePredicate::new()),
Arc::new(FilterNullJoinKeys::default()),
Arc::new(EliminateOuterJoin::new()),
// Filters can't be pushed down past Limits, we should do PushDownFilter after LimitPushDown
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index ff0ea5b23..35c5dacfa 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -14,7 +14,7 @@
//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan
-use crate::utils::conjunction;
+use crate::utils::{conjunction, split_conjunction};
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::{
@@ -421,7 +421,7 @@ fn push_down_join(
) -> Result<Option<LogicalPlan>> {
let mut predicates = match parent_predicate {
Some(parent_predicate) => {
- utils::split_conjunction_owned(utils::cnf_rewrite(parent_predicate.clone()))
+ utils::split_conjunction_owned(parent_predicate.clone())
}
None => vec![],
};
@@ -538,8 +538,21 @@ impl OptimizerRule for PushDownFilter {
let child_plan = filter.input.as_ref();
let new_plan = match child_plan {
LogicalPlan::Filter(child_filter) => {
- let new_predicate =
- and(filter.predicate.clone(), child_filter.predicate.clone());
+ let parents_predicates = split_conjunction(&filter.predicate);
+ let set: HashSet<&&Expr> = parents_predicates.iter().collect();
+
+ let new_predicates = parents_predicates
+ .iter()
+ .chain(
+ split_conjunction(&child_filter.predicate)
+ .iter()
+ .filter(|e| !set.contains(e)),
+ )
+ .map(|e| (*e).clone())
+ .collect::<Vec<_>>();
+ let new_predicate = conjunction(new_predicates).ok_or(
+ DataFusionError::Plan("at least one expression exists".to_string()),
+ )?;
let new_plan = LogicalPlan::Filter(Filter::try_new(
new_predicate,
child_filter.input.clone(),
@@ -638,9 +651,7 @@ impl OptimizerRule for PushDownFilter {
.map(|e| Ok(Column::from_qualified_name(e.display_name()?)))
.collect::<Result<HashSet<_>>>()?;
- let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
- filter.predicate.clone(),
- ));
+ let predicates = utils::split_conjunction_owned(filter.predicate.clone());
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
@@ -689,19 +700,15 @@ impl OptimizerRule for PushDownFilter {
}
}
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
- let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
- filter.predicate.clone(),
- ));
-
+ let predicates = utils::split_conjunction_owned(filter.predicate.clone());
push_down_all_join(predicates, &filter.input, left, right, vec![])?
}
LogicalPlan::TableScan(scan) => {
let mut new_scan_filters = scan.filters.clone();
let mut new_predicate = vec![];
- let filter_predicates = utils::split_conjunction_owned(
- utils::cnf_rewrite(filter.predicate.clone()),
- );
+ let filter_predicates =
+ utils::split_conjunction_owned(filter.predicate.clone());
for filter_expr in &filter_predicates {
let (preserve_filter_node, add_to_provider) =
@@ -754,7 +761,10 @@ impl PushDownFilter {
}
/// replaces columns by its name on the projection.
-fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<Expr> {
+pub fn replace_cols_by_name(
+ e: Expr,
+ replace_map: &HashMap<String, Expr>,
+) -> Result<Expr> {
struct ColumnReplacer<'a> {
replace_map: &'a HashMap<String, Expr>,
}
@@ -778,6 +788,7 @@ fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<
#[cfg(test)]
mod tests {
use super::*;
+ use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -801,6 +812,24 @@ mod tests {
Ok(())
}
+ fn assert_optimized_plan_eq_with_rewrite_predicate(
+ plan: &LogicalPlan,
+ expected: &str,
+ ) -> Result<()> {
+ let mut optimized_plan = RewriteDisjunctivePredicate::new()
+ .try_optimize(plan, &OptimizerContext::new())
+ .unwrap()
+ .expect("failed to optimize plan");
+ optimized_plan = PushDownFilter::new()
+ .try_optimize(&optimized_plan, &OptimizerContext::new())
+ .unwrap()
+ .expect("failed to optimize plan");
+ let formatted_plan = format!("{optimized_plan:?}");
+ assert_eq!(plan.schema(), optimized_plan.schema());
+ assert_eq!(expected, formatted_plan);
+ Ok(())
+ }
+
#[test]
fn filter_before_projection() -> Result<()> {
let table_scan = test_table_scan()?;
@@ -2281,20 +2310,19 @@ mod tests {
.build()?;
let expected = "\
- Filter: (test.a = d OR test.b = e) AND (test.a = d OR test.c < UInt32(10)) AND (test.b > UInt32(1) OR test.b = e)\
+ Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\
\n TableScan: test\
\n Projection: test1.a AS d, test1.a AS e\
\n TableScan: test1";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?;
// Originally global state which can help to avoid duplicate Filters been generated and pushed down.
// Now the global state is removed. Need to double confirm that avoid duplicate Filters.
let optimized_plan = PushDownFilter::new()
- .try_optimize(&plan, &OptimizerContext::new())
- .unwrap()
+ .try_optimize(&plan, &OptimizerContext::new())?
.expect("failed to optimize plan");
assert_optimized_plan_eq(&optimized_plan, expected)
}
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index d9c7477cb..ed0f07f55 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -29,7 +29,7 @@ use datafusion_expr::{
utils::from_plan,
Expr, Operator,
};
-use std::collections::{HashSet, VecDeque};
+use std::collections::HashSet;
use std::sync::Arc;
/// Convenience rule for writing optimizers: recursively invoke
@@ -170,104 +170,6 @@ fn split_binary_impl<'a>(
}
}
-/// Given a list of lists of [`Expr`]s, returns a list of lists of
-/// [`Expr`]s of expressions where there is one expression from each
-/// from each of the input expressions
-///
-/// For example, given the input `[[a, b], [c], [d, e]]` returns
-/// `[a, c, d], [a, c, e], [b, c, d], [b, c, e]]`.
-fn permutations(mut exprs: VecDeque<Vec<&Expr>>) -> Vec<Vec<&Expr>> {
- let first = if let Some(first) = exprs.pop_front() {
- first
- } else {
- return vec![];
- };
-
- // base case:
- if exprs.is_empty() {
- first.into_iter().map(|e| vec![e]).collect()
- } else {
- first
- .into_iter()
- .flat_map(|expr| {
- permutations(exprs.clone())
- .into_iter()
- .map(|expr_list| {
- // Create [expr, ...] for each permutation
- std::iter::once(expr)
- .chain(expr_list.into_iter())
- .collect::<Vec<&Expr>>()
- })
- .collect::<Vec<Vec<&Expr>>>()
- })
- .collect()
- }
-}
-
-const MAX_CNF_REWRITE_CONJUNCTS: usize = 10;
-
-/// Tries to convert an expression to conjunctive normal form (CNF).
-///
-/// Does not convert the expression if the total number of conjuncts
-/// (exprs ANDed together) would exceed [`MAX_CNF_REWRITE_CONJUNCTS`].
-///
-/// The following expression is in CNF:
-/// `(a OR b) AND (c OR d)`
-///
-/// The following is not in CNF:
-/// `(a AND b) OR c`.
-///
-/// But could be rewrite to a CNF expression:
-/// `(a OR c) AND (b OR c)`.
-///
-///
-/// # Example
-/// ```
-/// # use datafusion_expr::{col, lit};
-/// # use datafusion_optimizer::utils::cnf_rewrite;
-/// // (a=1 AND b=2)OR c = 3
-/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
-/// let expr2 = col("c").eq(lit(3));
-/// let expr = expr1.or(expr2);
-///
-/// //(a=1 or c=3)AND(b=2 or c=3)
-/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3)));
-/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3)));
-/// let expect = expr1.and(expr2);
-/// assert_eq!(expect, cnf_rewrite(expr));
-/// ```
-pub fn cnf_rewrite(expr: Expr) -> Expr {
- // Find all exprs joined by OR
- let disjuncts = split_binary(&expr, Operator::Or);
-
- // For each expr, split now on AND
- // A OR B OR C --> split each A, B and C
- let disjunct_conjuncts: VecDeque<Vec<&Expr>> = disjuncts
- .into_iter()
- .map(|e| split_binary(e, Operator::And))
- .collect::<VecDeque<_>>();
-
- // Decide if we want to distribute the clauses. Heuristic is
- // chosen to avoid creating huge predicates
- let num_conjuncts = disjunct_conjuncts
- .iter()
- .fold(1usize, |sz, exprs| sz.saturating_mul(exprs.len()));
-
- if disjunct_conjuncts.iter().any(|exprs| exprs.len() > 1)
- && num_conjuncts < MAX_CNF_REWRITE_CONJUNCTS
- {
- let or_clauses = permutations(disjunct_conjuncts)
- .into_iter()
- // form the OR clauses( A OR B OR C ..)
- .map(|exprs| disjunction(exprs.into_iter().cloned()).unwrap());
- conjunction(or_clauses).unwrap()
- }
- // otherwise return the original expression
- else {
- expr
- }
-}
-
/// Combines an array of filter expressions into a single filter
/// expression consisting of the input filter expressions joined with
/// logical AND.
@@ -614,7 +516,7 @@ mod tests {
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_expr::expr::Cast;
- use datafusion_expr::{col, lit, or, utils::expr_to_columns};
+ use datafusion_expr::{col, lit, utils::expr_to_columns};
use std::collections::HashSet;
use std::ops::Add;
@@ -815,135 +717,4 @@ mod tests {
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
-
- #[test]
- fn test_permutations() {
- assert_eq!(make_permutations(vec![]), vec![] as Vec<Vec<Expr>>)
- }
-
- #[test]
- fn test_permutations_one() {
- // [[a]] --> [[a]]
- assert_eq!(
- make_permutations(vec![vec![col("a")]]),
- vec![vec![col("a")]]
- )
- }
-
- #[test]
- fn test_permutations_two() {
- // [[a, b]] --> [[a], [b]]
- assert_eq!(
- make_permutations(vec![vec![col("a"), col("b")]]),
- vec![vec![col("a")], vec![col("b")]]
- )
- }
-
- #[test]
- fn test_permutations_two_and_one() {
- // [[a, b], [c]] --> [[a, c], [b, c]]
- assert_eq!(
- make_permutations(vec![vec![col("a"), col("b")], vec![col("c")]]),
- vec![vec![col("a"), col("c")], vec![col("b"), col("c")]]
- )
- }
-
- #[test]
- fn test_permutations_two_and_one_and_two() {
- // [[a, b], [c], [d, e]] --> [[a, c, d], [a, c, e], [b, c, d], [b, c, e]]
- assert_eq!(
- make_permutations(vec![
- vec![col("a"), col("b")],
- vec![col("c")],
- vec![col("d"), col("e")]
- ]),
- vec![
- vec![col("a"), col("c"), col("d")],
- vec![col("a"), col("c"), col("e")],
- vec![col("b"), col("c"), col("d")],
- vec![col("b"), col("c"), col("e")],
- ]
- )
- }
-
- /// call permutations with owned `Expr`s for easier testing
- fn make_permutations(exprs: impl IntoIterator<Item = Vec<Expr>>) -> Vec<Vec<Expr>> {
- let exprs = exprs.into_iter().collect::<Vec<_>>();
-
- let exprs: VecDeque<Vec<&Expr>> = exprs
- .iter()
- .map(|exprs| exprs.iter().collect::<Vec<&Expr>>())
- .collect();
-
- permutations(exprs)
- .into_iter()
- // copy &Expr --> Expr
- .map(|exprs| exprs.into_iter().cloned().collect())
- .collect()
- }
-
- #[test]
- fn test_rewrite_cnf() {
- let a_1 = col("a").eq(lit(1i64));
- let a_2 = col("a").eq(lit(2i64));
-
- let b_1 = col("b").eq(lit(1i64));
- let b_2 = col("b").eq(lit(2i64));
-
- // Test rewrite on a1_and_b2 and a2_and_b1 -> not change
- let expr1 = and(a_1.clone(), b_2.clone());
- let expect = expr1.clone();
- assert_eq!(expect, cnf_rewrite(expr1));
-
- // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1)
- let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone()));
- let expect = and(a_1.clone(), b_2.clone())
- .and(a_2.clone())
- .and(b_1.clone());
- assert_eq!(expect, cnf_rewrite(expr1));
-
- // Test rewrite on a1_or_b2 -> not change
- let expr1 = or(a_1.clone(), b_2.clone());
- let expect = expr1.clone();
- assert_eq!(expect, cnf_rewrite(expr1));
-
- // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1
- let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone()));
- let a1_or_a2 = or(a_1.clone(), a_2.clone());
- let a1_or_b1 = or(a_1.clone(), b_1.clone());
- let b2_or_a2 = or(b_2.clone(), a_2.clone());
- let b2_or_b1 = or(b_2.clone(), b_1.clone());
- let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1);
- assert_eq!(expect, cnf_rewrite(expr1));
-
- // Test rewrite on a1_or_b2 or a2_and_b1 -> ( a1_or_a2 or a2 ) and (a1_or_a2 or b1)
- let a1_or_b2 = or(a_1.clone(), b_2.clone());
- let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone()));
- let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone()));
- assert_eq!(expect, cnf_rewrite(expr1));
-
- // Test rewrite on a1_or_b2 or a2_or_b1 -> not change
- let expr1 = or(or(a_1, b_2), or(a_2, b_1));
- let expect = expr1.clone();
- assert_eq!(expect, cnf_rewrite(expr1));
- }
-
- #[test]
- fn test_rewrite_cnf_overflow() {
- // in this situation:
- // AND = (a=1 and b=2)
- // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100 (a=1 or b=2)
- // which cause size expansion.
-
- let mut expr1 = col("test1").eq(lit(1i64));
- let expr2 = col("test2").eq(lit(2i64));
-
- for _i in 0..9 {
- expr1 = expr1.clone().and(expr2.clone());
- }
- let expr3 = expr1.clone();
- let expr = or(expr1, expr3);
-
- assert_eq!(expr, cnf_rewrite(expr.clone()));
- }
}