You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ja...@apache.org on 2023/06/01 03:26:23 UTC
[arrow-datafusion] branch main updated: Rewrite large OR chains as IN lists (#6414)
This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2f264ab154 Rewrite large OR chains as IN lists (#6414)
2f264ab154 is described below
commit 2f264ab154bdf1b2be7737c75d46d1efefacb737
Author: Armin Primadi <ap...@gmail.com>
AuthorDate: Thu Jun 1 10:26:16 2023 +0700
Rewrite large OR chains as IN lists (#6414)
* Naive large or chains simplifier
* Fix test
* Added assert_text_eq for line diff comparison for easier debugging
* Fix test
* Fix test
* Add test
* Add test
* Add test
* Add more tests
* Move OrInListSimplifier to its own file
* Rename "left-heavy" to "left-deep" to be consistent with DB parlance
* Remove no longer used dev-dependencies on benchmarks
* Fix benchmark test
---------
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/core/src/physical_plan/planner.rs | 2 +-
.../tests/sqllogictests/test_files/predicates.slt | 70 ++++++++++++++++
.../sqllogictests/test_files/tpch/q12.slt.part | 6 +-
.../sqllogictests/test_files/tpch/q19.slt.part | 6 +-
.../src/simplify_expressions/expr_simplifier.rs | 76 ++++++++++++++----
.../optimizer/src/simplify_expressions/mod.rs | 1 +
.../simplify_expressions/or_in_list_simplifier.rs | 92 ++++++++++++++++++++++
.../src/simplify_expressions/simplify_exprs.rs | 4 +-
8 files changed, 231 insertions(+), 26 deletions(-)
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 35b209c7c5..4527440906 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -2230,7 +2230,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
- let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } } }";
+ let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }";
let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
diff --git a/datafusion/core/tests/sqllogictests/test_files/predicates.slt b/datafusion/core/tests/sqllogictests/test_files/predicates.slt
index 952a369642..f37495c47c 100644
--- a/datafusion/core/tests/sqllogictests/test_files/predicates.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/predicates.slt
@@ -249,6 +249,76 @@ SELECT * FROM test WHERE column1 IN ('foo', 'Bar', 'fazzz')
foo
fazzz
+
+###
+# Test logical plan simplifies large OR chains
+###
+
+statement ok
+set datafusion.explain.logical_plan_only = true
+
+# Number of OR statements is less than or equal to threshold
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz'
+----
+logical_plan
+Filter: test.column1 = Utf8("foo") OR test.column1 = Utf8("bar") OR test.column1 = Utf8("fazzz")
+--TableScan: test projection=[column1]
+
+# Number of OR statements is greater than threshold
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz' OR column1 = 'barfoo'
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
+--TableScan: test projection=[column1]
+
+# Complex OR statements
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz' OR column1 = 'barfoo' OR false OR column1 = 'foobar'
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo"), Utf8("foobar")])
+--TableScan: test projection=[column1]
+
+# Balanced OR structures
+query TT
+EXPLAIN SELECT * FROM test WHERE (column1 = 'foo' OR column1 = 'bar') OR (column1 = 'fazzz' OR column1 = 'barfoo')
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
+--TableScan: test projection=[column1]
+
+# Right-deep OR structures
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR (column1 = 'bar' OR (column1 = 'fazzz' OR column1 = 'barfoo'))
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
+--TableScan: test projection=[column1]
+
+# Not simplifiable, mixed column
+query TT
+EXPLAIN SELECT * FROM aggregate_test_100
+WHERE (c2 = 1 OR c3 = 100) OR (c2 = 2 OR c2 = 3 OR c2 = 4)
+----
+logical_plan
+Filter: aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 = Int8(2) OR aggregate_test_100.c2 = Int8(3) OR aggregate_test_100.c2 = Int8(4)
+--TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 = Int8(2) OR aggregate_test_100.c2 = Int8(3) OR aggregate_test_100.c2 = Int8(4)]
+
+# Partially simplifiable, mixed column
+query TT
+EXPLAIN SELECT * FROM aggregate_test_100
+WHERE (c2 = 1 OR c3 = 100) OR (c2 = 2 OR c2 = 3 OR c2 = 4 OR c2 = 5)
+----
+logical_plan
+Filter: aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 IN ([Int8(2), Int8(3), Int8(4), Int8(5)])
+--TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 IN ([Int8(2), Int8(3), Int8(4), Int8(5)])]
+
+statement ok
+set datafusion.explain.logical_plan_only = false
+
+
# async fn test_expect_all
query IR
SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT ALL SELECT int_col, double_col FROM alltypes_plain where int_col < 1
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part
index c1670e6d5c..fdada35952 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part
@@ -55,8 +55,8 @@ Sort: lineitem.l_shipmode ASC NULLS LAST
------Projection: lineitem.l_shipmode, orders.o_orderpriority
--------Inner Join: lineitem.l_orderkey = orders.o_orderkey
----------Projection: lineitem.l_orderkey, lineitem.l_shipmode
-------------Filter: (lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
---------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL"), lineitem.l_commitdate < lineitem.l_receiptdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < Date32("9131")]
+------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
+--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_commitdate < lineitem.l_receiptdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < Date32("9131")]
----------TableScan: orders projection=[o_orderkey, o_orderpriority]
physical_plan
SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
@@ -73,7 +73,7 @@ SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
----------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4
------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_shipmode@4 as l_shipmode]
--------------------------CoalesceBatchesExec: target_batch_size=8192
-----------------------------FilterExec: (l_shipmode@4 = SHIP OR l_shipmode@4 = MAIL) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
+----------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 = SHIP) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], has_header=false
--------------------CoalesceBatchesExec: target_batch_size=8192
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part
index 06c6f5ed59..1a91fed124 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part
@@ -59,8 +59,8 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re
----Projection: lineitem.l_extendedprice, lineitem.l_discount
------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128 [...]
--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount
-----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
-------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR line [...]
+----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
+------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR line [...]
--------Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND [...]
----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container I [...]
physical_plan
@@ -75,7 +75,7 @@ ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_disco
----------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 0 }], 4), input_partitions=4
------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount]
--------------------CoalesceBatchesExec: target_batch_size=8192
-----------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR REG OR l_shipmode@5 = AIR) AND l_shipinstruct@4 = DELIVER IN PERSON
+----------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON
------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], has_header=false
--------------CoalesceBatchesExec: target_batch_size=8192
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index a8d6876a23..98fec3f7c9 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -19,7 +19,9 @@
use std::ops::Not;
+use super::or_in_list_simplifier::OrInListSimplifier;
use super::utils::*;
+
use crate::analyzer::type_coercion::TypeCoercionRewriter;
use crate::simplify_expressions::regex::simplify_regex_expr;
use arrow::{
@@ -116,6 +118,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
pub fn simplify(&self, expr: Expr) -> Result<Expr> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
+ let mut or_in_list_simplifier = OrInListSimplifier::new();
// TODO iterate until no changes are made during rewrite
// (evaluating constants can enable new simplifications and
@@ -123,6 +126,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
// https://github.com/apache/arrow-datafusion/issues/1160
expr.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)?
+ .rewrite(&mut or_in_list_simplifier)?
// run both passes twice to try an minimize simplifications that we missed
.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)
@@ -432,17 +436,37 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
{
let first_val = list[0].clone();
if negated {
- list.into_iter()
- .skip(1)
- .fold((*expr.clone()).not_eq(first_val), |acc, y| {
- (*expr.clone()).not_eq(y).and(acc)
- })
+ list.into_iter().skip(1).fold(
+ (*expr.clone()).not_eq(first_val),
+ |acc, y| {
+ // Note that `A and B and C and D` is a left-deep tree structure;
+ // as such, we want to maintain this structure as much as possible
+ // to avoid reordering the expression during each optimization
+ // pass.
+ //
+ // Left-deep tree structure for `A and B and C and D`:
+ // ```
+ // &
+ // / \
+ // & D
+ // / \
+ // & C
+ // / \
+ // A B
+ // ```
+ //
+ // The code below maintains the left-deep tree structure.
+ acc.and((*expr.clone()).not_eq(y))
+ },
+ )
} else {
- list.into_iter()
- .skip(1)
- .fold((*expr.clone()).eq(first_val), |acc, y| {
- (*expr.clone()).eq(y).or(acc)
- })
+ list.into_iter().skip(1).fold(
+ (*expr.clone()).eq(first_val),
+ |acc, y| {
+ // Same reasoning as above
+ acc.or((*expr.clone()).eq(y))
+ },
+ )
}
}
//
@@ -2888,11 +2912,11 @@ mod tests {
assert_eq!(
simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
- col("c1").eq(lit(2)).or(col("c1").eq(lit(1)))
+ col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
);
assert_eq!(
simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
- col("c1").not_eq(lit(2)).and(col("c1").not_eq(lit(1)))
+ col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
);
let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
@@ -2918,7 +2942,7 @@ mod tests {
let subquery2 =
scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));
- // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery2> AND c1 != <subquery1>
+ // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
assert_eq!(
simplify(in_list(
col("c1"),
@@ -2926,18 +2950,36 @@ mod tests {
true
)),
col("c1")
- .not_eq(subquery2.clone())
- .and(col("c1").not_eq(subquery1.clone()))
+ .not_eq(subquery1.clone())
+ .and(col("c1").not_eq(subquery2.clone()))
);
- // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery2> OR c1 == <subquery1>
+ // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
assert_eq!(
simplify(in_list(
col("c1"),
vec![subquery1.clone(), subquery2.clone()],
false
)),
- col("c1").eq(subquery2).or(col("c1").eq(subquery1))
+ col("c1").eq(subquery1).or(col("c1").eq(subquery2))
+ );
+
+ // c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) ->
+ // c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8)
+ let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
+ in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
+ );
+ assert_eq!(simplify(expr.clone()), expr);
+ }
+
+ #[test]
+ fn simplify_large_or() {
+ let expr = (0..5)
+ .map(|i| col("c1").eq(lit(i)))
+ .fold(lit(false), |acc, e| acc.or(e));
+ assert_eq!(
+ simplify(expr),
+ in_list(col("c1"), (0..5).map(lit).collect(), false),
);
}
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs
index 975976e1f5..dfa0fe7043 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -17,6 +17,7 @@
pub mod context;
pub mod expr_simplifier;
+mod or_in_list_simplifier;
mod regex;
pub mod simplify_exprs;
mod utils;
diff --git a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
new file mode 100644
index 0000000000..10f3aa0278
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module implements a rule that simplifies OR expressions into IN list expressions
+
+use datafusion_common::tree_node::TreeNodeRewriter;
+use datafusion_common::Result;
+use datafusion_expr::expr::InList;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
+
+/// Combine multiple OR expressions into a single IN list expression if possible
+///
+/// e.g. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
+pub(super) struct OrInListSimplifier {}
+
+impl OrInListSimplifier {
+ pub(super) fn new() -> Self {
+ Self {}
+ }
+}
+
+impl TreeNodeRewriter for OrInListSimplifier {
+ type N = Expr;
+
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr {
+ if *op == Operator::Or {
+ let left = as_inlist(left);
+ let right = as_inlist(right);
+ if let (Some(lhs), Some(rhs)) = (left, right) {
+ if lhs.expr.try_into_col().is_ok()
+ && rhs.expr.try_into_col().is_ok()
+ && lhs.expr == rhs.expr
+ && !lhs.negated
+ && !rhs.negated
+ {
+ let mut list = vec![];
+ list.extend(lhs.list);
+ list.extend(rhs.list);
+ let merged_inlist = InList {
+ expr: lhs.expr,
+ list,
+ negated: false,
+ };
+ return Ok(Expr::InList(merged_inlist));
+ }
+ }
+ }
+ }
+
+ Ok(expr)
+ }
+}
+
+/// Try to convert an expression to an in-list expression
+fn as_inlist(expr: &Expr) -> Option<InList> {
+ match expr {
+ Expr::InList(inlist) => Some(inlist.clone()),
+ Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
+ let unboxed_left = *left.clone();
+ let unboxed_right = *right.clone();
+ match (&unboxed_left, &unboxed_right) {
+ (Expr::Column(_), Expr::Literal(_)) => Some(InList {
+ expr: left.clone(),
+ list: vec![unboxed_right],
+ negated: false,
+ }),
+ (Expr::Literal(_), Expr::Column(_)) => Some(InList {
+ expr: right.clone(),
+ list: vec![unboxed_left],
+ negated: false,
+ }),
+ _ => None,
+ }
+ }
+ _ => None,
+ }
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index 6b0496a0cc..239497d9fa 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -655,7 +655,7 @@ mod tests {
.filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())?
.build()?;
let expected =
- "Filter: test.d != Int32(3) AND test.d != Int32(2) AND test.d != Int32(1)\
+ "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
@@ -669,7 +669,7 @@ mod tests {
.filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())?
.build()?;
let expected =
- "Filter: test.d = Int32(3) OR test.d = Int32(2) OR test.d = Int32(1)\
+ "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)