You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ja...@apache.org on 2023/06/01 03:26:23 UTC

[arrow-datafusion] branch main updated: Rewrite large OR chains as IN lists (#6414)

This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f264ab154 Rewrite large OR chains as IN lists (#6414)
2f264ab154 is described below

commit 2f264ab154bdf1b2be7737c75d46d1efefacb737
Author: Armin Primadi <ap...@gmail.com>
AuthorDate: Thu Jun 1 10:26:16 2023 +0700

    Rewrite large OR chains as IN lists (#6414)
    
    * Naive simplifier for large OR chains
    
    * Fix test
    
    * Added assert_text_eq for line diff comparison for easier debugging
    
    * Fix test
    
    * Fix test
    
    * Add test
    
    * Add test
    
    * Add test
    
    * Add more tests
    
    * Move OrInListSimplifier to its own file
    
    * Rename "left-heavy" to "left-deep" to be consistent with DB parlance
    
    * Remove no longer used dev-dependencies on benchmarks
    
    * Fix benchmark test
    
    ---------
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/core/src/physical_plan/planner.rs       |  2 +-
 .../tests/sqllogictests/test_files/predicates.slt  | 70 ++++++++++++++++
 .../sqllogictests/test_files/tpch/q12.slt.part     |  6 +-
 .../sqllogictests/test_files/tpch/q19.slt.part     |  6 +-
 .../src/simplify_expressions/expr_simplifier.rs    | 76 ++++++++++++++----
 .../optimizer/src/simplify_expressions/mod.rs      |  1 +
 .../simplify_expressions/or_in_list_simplifier.rs  | 92 ++++++++++++++++++++++
 .../src/simplify_expressions/simplify_exprs.rs     |  4 +-
 8 files changed, 231 insertions(+), 26 deletions(-)

diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 35b209c7c5..4527440906 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -2230,7 +2230,7 @@ mod tests {
         let execution_plan = plan(&logical_plan).await?;
         // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
 
-        let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } } }";
+        let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }";
 
         let actual = format!("{execution_plan:?}");
         assert!(actual.contains(expected), "{}", actual);
diff --git a/datafusion/core/tests/sqllogictests/test_files/predicates.slt b/datafusion/core/tests/sqllogictests/test_files/predicates.slt
index 952a369642..f37495c47c 100644
--- a/datafusion/core/tests/sqllogictests/test_files/predicates.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/predicates.slt
@@ -249,6 +249,76 @@ SELECT * FROM test WHERE column1 IN ('foo', 'Bar', 'fazzz')
 foo
 fazzz
 
+
+###
+# Test logical plan simplifies large OR chains
+###
+
+statement ok
+set datafusion.explain.logical_plan_only = true
+
+# Number of OR statements is less than or equal to threshold
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz'
+----
+logical_plan
+Filter: test.column1 = Utf8("foo") OR test.column1 = Utf8("bar") OR test.column1 = Utf8("fazzz")
+--TableScan: test projection=[column1]
+
+# Number of OR statements is greater than threshold
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz' OR column1 = 'barfoo'
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
+--TableScan: test projection=[column1]
+
+# Complex OR statements
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz' OR column1 = 'barfoo' OR false OR column1 = 'foobar'
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo"), Utf8("foobar")])
+--TableScan: test projection=[column1]
+
+# Balanced OR structures
+query TT
+EXPLAIN SELECT * FROM test WHERE (column1 = 'foo' OR column1 = 'bar') OR (column1 = 'fazzz' OR column1 = 'barfoo')
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
+--TableScan: test projection=[column1]
+
+# Right-deep OR structures
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR (column1 = 'bar' OR (column1 = 'fazzz' OR column1 = 'barfoo'))
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
+--TableScan: test projection=[column1]
+
+# Not simplifiable, mixed column
+query TT
+EXPLAIN SELECT * FROM aggregate_test_100
+WHERE (c2 = 1 OR c3 = 100) OR (c2 = 2 OR c2 = 3 OR c2 = 4)
+----
+logical_plan
+Filter: aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 = Int8(2) OR aggregate_test_100.c2 = Int8(3) OR aggregate_test_100.c2 = Int8(4)
+--TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 = Int8(2) OR aggregate_test_100.c2 = Int8(3) OR aggregate_test_100.c2 = Int8(4)]
+
+# Partially simplifiable, mixed column
+query TT
+EXPLAIN SELECT * FROM aggregate_test_100
+WHERE (c2 = 1 OR c3 = 100) OR (c2 = 2 OR c2 = 3 OR c2 = 4 OR c2 = 5)
+----
+logical_plan
+Filter: aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 IN ([Int8(2), Int8(3), Int8(4), Int8(5)])
+--TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 IN ([Int8(2), Int8(3), Int8(4), Int8(5)])]
+
+statement ok
+set datafusion.explain.logical_plan_only = false
+
+
 # async fn test_expect_all
 query IR
 SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT ALL SELECT int_col, double_col FROM alltypes_plain where int_col < 1
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part
index c1670e6d5c..fdada35952 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part
@@ -55,8 +55,8 @@ Sort: lineitem.l_shipmode ASC NULLS LAST
 ------Projection: lineitem.l_shipmode, orders.o_orderpriority
 --------Inner Join: lineitem.l_orderkey = orders.o_orderkey
 ----------Projection: lineitem.l_orderkey, lineitem.l_shipmode
-------------Filter: (lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
---------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL"), lineitem.l_commitdate < lineitem.l_receiptdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < Date32("9131")]
+------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
+--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_commitdate < lineitem.l_receiptdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < Date32("9131")]
 ----------TableScan: orders projection=[o_orderkey, o_orderpriority]
 physical_plan
 SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
@@ -73,7 +73,7 @@ SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
 ----------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4
 ------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_shipmode@4 as l_shipmode]
 --------------------------CoalesceBatchesExec: target_batch_size=8192
-----------------------------FilterExec: (l_shipmode@4 = SHIP OR l_shipmode@4 = MAIL) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
+----------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 = SHIP) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
 ------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 --------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], has_header=false
 --------------------CoalesceBatchesExec: target_batch_size=8192
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part
index 06c6f5ed59..1a91fed124 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part
@@ -59,8 +59,8 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re
 ----Projection: lineitem.l_extendedprice, lineitem.l_discount
 ------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128 [...]
 --------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount
-----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
-------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR line [...]
+----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
+------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR line [...]
 --------Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND  [...]
 ----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container I [...]
 physical_plan
@@ -75,7 +75,7 @@ ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_disco
 ----------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 0 }], 4), input_partitions=4
 ------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount]
 --------------------CoalesceBatchesExec: target_batch_size=8192
-----------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR REG OR l_shipmode@5 = AIR) AND l_shipinstruct@4 = DELIVER IN PERSON
+----------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON
 ------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 --------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], has_header=false
 --------------CoalesceBatchesExec: target_batch_size=8192
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index a8d6876a23..98fec3f7c9 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -19,7 +19,9 @@
 
 use std::ops::Not;
 
+use super::or_in_list_simplifier::OrInListSimplifier;
 use super::utils::*;
+
 use crate::analyzer::type_coercion::TypeCoercionRewriter;
 use crate::simplify_expressions::regex::simplify_regex_expr;
 use arrow::{
@@ -116,6 +118,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
     pub fn simplify(&self, expr: Expr) -> Result<Expr> {
         let mut simplifier = Simplifier::new(&self.info);
         let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
+        let mut or_in_list_simplifier = OrInListSimplifier::new();
 
         // TODO iterate until no changes are made during rewrite
         // (evaluating constants can enable new simplifications and
@@ -123,6 +126,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
         // https://github.com/apache/arrow-datafusion/issues/1160
         expr.rewrite(&mut const_evaluator)?
             .rewrite(&mut simplifier)?
+            .rewrite(&mut or_in_list_simplifier)?
             // run both passes twice to try an minimize simplifications that we missed
             .rewrite(&mut const_evaluator)?
             .rewrite(&mut simplifier)
@@ -432,17 +436,37 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
             {
                 let first_val = list[0].clone();
                 if negated {
-                    list.into_iter()
-                        .skip(1)
-                        .fold((*expr.clone()).not_eq(first_val), |acc, y| {
-                            (*expr.clone()).not_eq(y).and(acc)
-                        })
+                    list.into_iter().skip(1).fold(
+                        (*expr.clone()).not_eq(first_val),
+                        |acc, y| {
+                            // Note that `A and B and C and D` is a left-deep tree structure
+                            // as such we want to maintain this structure as much as possible
+                            // to avoid reordering the expression during each optimization
+                            // pass.
+                            //
+                            // Left-deep tree structure for `A and B and C and D`:
+                            // ```
+                            //        &
+                            //       / \
+                            //      &   D
+                            //     / \
+                            //    &   C
+                            //   / \
+                            //  A   B
+                            // ```
+                            //
+                            // The code below maintain the left-deep tree structure.
+                            acc.and((*expr.clone()).not_eq(y))
+                        },
+                    )
                 } else {
-                    list.into_iter()
-                        .skip(1)
-                        .fold((*expr.clone()).eq(first_val), |acc, y| {
-                            (*expr.clone()).eq(y).or(acc)
-                        })
+                    list.into_iter().skip(1).fold(
+                        (*expr.clone()).eq(first_val),
+                        |acc, y| {
+                            // Same reasoning as above
+                            acc.or((*expr.clone()).eq(y))
+                        },
+                    )
                 }
             }
             //
@@ -2888,11 +2912,11 @@ mod tests {
 
         assert_eq!(
             simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
-            col("c1").eq(lit(2)).or(col("c1").eq(lit(1)))
+            col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
         );
         assert_eq!(
             simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
-            col("c1").not_eq(lit(2)).and(col("c1").not_eq(lit(1)))
+            col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
         );
 
         let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
@@ -2918,7 +2942,7 @@ mod tests {
         let subquery2 =
             scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));
 
-        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery2> AND c1 != <subquery1>
+        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
         assert_eq!(
             simplify(in_list(
                 col("c1"),
@@ -2926,18 +2950,36 @@ mod tests {
                 true
             )),
             col("c1")
-                .not_eq(subquery2.clone())
-                .and(col("c1").not_eq(subquery1.clone()))
+                .not_eq(subquery1.clone())
+                .and(col("c1").not_eq(subquery2.clone()))
         );
 
-        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery2> OR c1 == <subquery1>
+        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
         assert_eq!(
             simplify(in_list(
                 col("c1"),
                 vec![subquery1.clone(), subquery2.clone()],
                 false
             )),
-            col("c1").eq(subquery2).or(col("c1").eq(subquery1))
+            col("c1").eq(subquery1).or(col("c1").eq(subquery2))
+        );
+
+        // c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) ->
+        // c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8)
+        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
+            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
+        );
+        assert_eq!(simplify(expr.clone()), expr);
+    }
+
+    #[test]
+    fn simplify_large_or() {
+        let expr = (0..5)
+            .map(|i| col("c1").eq(lit(i)))
+            .fold(lit(false), |acc, e| acc.or(e));
+        assert_eq!(
+            simplify(expr),
+            in_list(col("c1"), (0..5).map(lit).collect(), false),
         );
     }
 
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs
index 975976e1f5..dfa0fe7043 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -17,6 +17,7 @@
 
 pub mod context;
 pub mod expr_simplifier;
+mod or_in_list_simplifier;
 mod regex;
 pub mod simplify_exprs;
 mod utils;
diff --git a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
new file mode 100644
index 0000000000..10f3aa0278
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module implements a rule that simplifies OR expressions into IN list expressions
+
+use datafusion_common::tree_node::TreeNodeRewriter;
+use datafusion_common::Result;
+use datafusion_expr::expr::InList;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
+
+/// Combine multiple OR expressions into a single IN list expression if possible
+///
+/// i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
+pub(super) struct OrInListSimplifier {}
+
+impl OrInListSimplifier {
+    pub(super) fn new() -> Self {
+        Self {}
+    }
+}
+
+impl TreeNodeRewriter for OrInListSimplifier {
+    type N = Expr;
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr {
+            if *op == Operator::Or {
+                let left = as_inlist(left);
+                let right = as_inlist(right);
+                if let (Some(lhs), Some(rhs)) = (left, right) {
+                    if lhs.expr.try_into_col().is_ok()
+                        && rhs.expr.try_into_col().is_ok()
+                        && lhs.expr == rhs.expr
+                        && !lhs.negated
+                        && !rhs.negated
+                    {
+                        let mut list = vec![];
+                        list.extend(lhs.list);
+                        list.extend(rhs.list);
+                        let merged_inlist = InList {
+                            expr: lhs.expr,
+                            list,
+                            negated: false,
+                        };
+                        return Ok(Expr::InList(merged_inlist));
+                    }
+                }
+            }
+        }
+
+        Ok(expr)
+    }
+}
+
+/// Try to convert an expression to an in-list expression
+fn as_inlist(expr: &Expr) -> Option<InList> {
+    match expr {
+        Expr::InList(inlist) => Some(inlist.clone()),
+        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
+            let unboxed_left = *left.clone();
+            let unboxed_right = *right.clone();
+            match (&unboxed_left, &unboxed_right) {
+                (Expr::Column(_), Expr::Literal(_)) => Some(InList {
+                    expr: left.clone(),
+                    list: vec![unboxed_right],
+                    negated: false,
+                }),
+                (Expr::Literal(_), Expr::Column(_)) => Some(InList {
+                    expr: right.clone(),
+                    list: vec![unboxed_left],
+                    negated: false,
+                }),
+                _ => None,
+            }
+        }
+        _ => None,
+    }
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index 6b0496a0cc..239497d9fa 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -655,7 +655,7 @@ mod tests {
             .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())?
             .build()?;
         let expected =
-            "Filter: test.d != Int32(3) AND test.d != Int32(2) AND test.d != Int32(1)\
+            "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\
         \n  TableScan: test";
 
         assert_optimized_plan_eq(&plan, expected)
@@ -669,7 +669,7 @@ mod tests {
             .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())?
             .build()?;
         let expected =
-            "Filter: test.d = Int32(3) OR test.d = Int32(2) OR test.d = Int32(1)\
+            "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\
         \n  TableScan: test";
 
         assert_optimized_plan_eq(&plan, expected)