You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/04/28 14:16:43 UTC

[arrow-datafusion] branch main updated: fix: `common_subexpr_eliminate` and aggregates (#6129)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7eca046223 fix: `common_subexpr_eliminate` and aggregates (#6129)
7eca046223 is described below

commit 7eca046223cc41ffb1953163c42aece42af2e485
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Fri Apr 28 16:16:38 2023 +0200

    fix: `common_subexpr_eliminate` and aggregates (#6129)
    
    * fix: `common_subexpr_eliminate` and aggregates
    
    Fixes #6116.
    
    * test: make `common_subexpr_eliminate` more readable
---
 .../optimizer/src/common_subexpr_eliminate.rs      | 402 +++++++++++++++++----
 1 file changed, 332 insertions(+), 70 deletions(-)

diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 9f97de04a4..be7481f38d 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -25,7 +25,9 @@ use arrow::datatypes::DataType;
 use datafusion_common::tree_node::{
     RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion,
 };
-use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result};
+use datafusion_common::{
+    Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
+};
 use datafusion_expr::{
     col,
     logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window},
@@ -58,17 +60,14 @@ type Identifier = String;
 pub struct CommonSubexprEliminate {}
 
 impl CommonSubexprEliminate {
-    fn rewrite_expr(
+    fn rewrite_exprs_list(
         &self,
         exprs_list: &[&[Expr]],
         arrays_list: &[&[Vec<(usize, String)>]],
-        input: &LogicalPlan,
         expr_set: &mut ExprSet,
-        config: &dyn OptimizerConfig,
-    ) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
-        let mut affected_id = BTreeSet::<Identifier>::new();
-
-        let rewrite_exprs = exprs_list
+        affected_id: &mut BTreeSet<Identifier>,
+    ) -> Result<Vec<Vec<Expr>>> {
+        exprs_list
             .iter()
             .zip(arrays_list.iter())
             .map(|(exprs, arrays)| {
@@ -77,11 +76,25 @@ impl CommonSubexprEliminate {
                     .cloned()
                     .zip(arrays.iter())
                     .map(|(expr, id_array)| {
-                        replace_common_expr(expr, id_array, expr_set, &mut affected_id)
+                        replace_common_expr(expr, id_array, expr_set, affected_id)
                     })
                     .collect::<Result<Vec<_>>>()
             })
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<Vec<_>>>()
+    }
+
+    fn rewrite_expr(
+        &self,
+        exprs_list: &[&[Expr]],
+        arrays_list: &[&[Vec<(usize, String)>]],
+        input: &LogicalPlan,
+        expr_set: &mut ExprSet,
+        config: &dyn OptimizerConfig,
+    ) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
+        let mut affected_id = BTreeSet::<Identifier>::new();
+
+        let rewrite_exprs =
+            self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?;
 
         let mut new_input = self
             .try_optimize(input, config)?
@@ -111,7 +124,8 @@ impl OptimizerRule for CommonSubexprEliminate {
                 ..
             }) => {
                 let input_schema = Arc::clone(input.schema());
-                let arrays = to_arrays(expr, input_schema, &mut expr_set)?;
+                let arrays =
+                    to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
 
                 let (mut new_expr, new_input) =
                     self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?;
@@ -131,6 +145,7 @@ impl OptimizerRule for CommonSubexprEliminate {
                     &mut expr_set,
                     &mut id_array,
                     input_schema,
+                    ExprMask::Normal,
                 )?;
 
                 let (mut new_expr, new_input) = self.rewrite_expr(
@@ -158,7 +173,12 @@ impl OptimizerRule for CommonSubexprEliminate {
                 schema,
             }) => {
                 let input_schema = Arc::clone(input.schema());
-                let arrays = to_arrays(window_expr, input_schema, &mut expr_set)?;
+                let arrays = to_arrays(
+                    window_expr,
+                    input_schema,
+                    &mut expr_set,
+                    ExprMask::Normal,
+                )?;
 
                 let (mut new_expr, new_input) = self.rewrite_expr(
                     &[window_expr],
@@ -181,10 +201,16 @@ impl OptimizerRule for CommonSubexprEliminate {
                 schema,
                 ..
             }) => {
+                // rewrite inputs
                 let input_schema = Arc::clone(input.schema());
-                let group_arrays =
-                    to_arrays(group_expr, Arc::clone(&input_schema), &mut expr_set)?;
-                let aggr_arrays = to_arrays(aggr_expr, input_schema, &mut expr_set)?;
+                let group_arrays = to_arrays(
+                    group_expr,
+                    Arc::clone(&input_schema),
+                    &mut expr_set,
+                    ExprMask::Normal,
+                )?;
+                let aggr_arrays =
+                    to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
 
                 let (mut new_expr, new_input) = self.rewrite_expr(
                     &[group_expr, aggr_expr],
@@ -197,16 +223,93 @@ impl OptimizerRule for CommonSubexprEliminate {
                 let new_aggr_expr = pop_expr(&mut new_expr)?;
                 let new_group_expr = pop_expr(&mut new_expr)?;
 
-                Some(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
-                    Arc::new(new_input),
-                    new_group_expr,
-                    new_aggr_expr,
-                    schema.clone(),
-                )?))
+                // create potential projection on top
+                let mut expr_set = ExprSet::new();
+                let new_input_schema = Arc::clone(new_input.schema());
+                let aggr_arrays = to_arrays(
+                    &new_aggr_expr,
+                    new_input_schema.clone(),
+                    &mut expr_set,
+                    ExprMask::NormalAndAggregates,
+                )?;
+                let mut affected_id = BTreeSet::<Identifier>::new();
+                let mut rewritten = self.rewrite_exprs_list(
+                    &[&new_aggr_expr],
+                    &[&aggr_arrays],
+                    &mut expr_set,
+                    &mut affected_id,
+                )?;
+                let rewritten = pop_expr(&mut rewritten)?;
+
+                if affected_id.is_empty() {
+                    Some(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
+                        Arc::new(new_input),
+                        new_group_expr,
+                        new_aggr_expr,
+                        schema.clone(),
+                    )?))
+                } else {
+                    let mut agg_exprs = vec![];
+
+                    for id in affected_id {
+                        match expr_set.get(&id) {
+                            Some((expr, _, _)) => {
+                                // todo: check `nullable`
+                                agg_exprs.push(expr.clone().alias(&id));
+                            }
+                            _ => {
+                                return Err(DataFusionError::Internal(
+                                    "expr_set invalid state".to_string(),
+                                ));
+                            }
+                        }
+                    }
+
+                    let mut proj_exprs = vec![];
+                    for expr in &new_group_expr {
+                        let out_name = expr.to_field(&new_input_schema)?.qualified_name();
+                        proj_exprs.push(Expr::Column(Column::from_name(out_name)));
+                    }
+                    for (expr_rewritten, expr_orig) in
+                        rewritten.into_iter().zip(new_aggr_expr)
+                    {
+                        if expr_rewritten == expr_orig {
+                            if let Expr::Alias(expr, name) = expr_rewritten {
+                                agg_exprs.push(expr.alias(&name));
+                                proj_exprs.push(Expr::Column(Column::from_name(name)));
+                            } else {
+                                let id = ExprIdentifierVisitor::<'static>::desc_expr(
+                                    &expr_rewritten,
+                                );
+                                let out_name = expr_rewritten
+                                    .to_field(&new_input_schema)?
+                                    .qualified_name();
+                                agg_exprs.push(expr_rewritten.alias(&id));
+                                proj_exprs.push(
+                                    Expr::Column(Column::from_name(id)).alias(out_name),
+                                );
+                            }
+                        } else {
+                            proj_exprs.push(expr_rewritten);
+                        }
+                    }
+
+                    let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+                        Arc::new(new_input),
+                        new_group_expr,
+                        agg_exprs,
+                    )?);
+
+                    Some(LogicalPlan::Projection(Projection::try_new(
+                        proj_exprs,
+                        Arc::new(agg),
+                    )?))
+                }
             }
             LogicalPlan::Sort(Sort { expr, input, fetch }) => {
                 let input_schema = Arc::clone(input.schema());
-                let arrays = to_arrays(expr, input_schema, &mut expr_set)?;
+                let arrays =
+                    to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
 
                 let (mut new_expr, new_input) =
                     self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?;
@@ -282,11 +385,18 @@ fn to_arrays(
     expr: &[Expr],
     input_schema: DFSchemaRef,
     expr_set: &mut ExprSet,
+    expr_mask: ExprMask,
 ) -> Result<Vec<Vec<(usize, String)>>> {
     expr.iter()
         .map(|e| {
             let mut id_array = vec![];
-            expr_to_identifier(e, expr_set, &mut id_array, Arc::clone(&input_schema))?;
+            expr_to_identifier(
+                e,
+                expr_set,
+                &mut id_array,
+                Arc::clone(&input_schema),
+                expr_mask,
+            )?;
 
             Ok(id_array)
         })
@@ -346,6 +456,49 @@ fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalP
     )
 }
 
+/// Which type of [expressions](Expr) should be considered for rewriting?
+#[derive(Debug, Clone, Copy)]
+enum ExprMask {
+    /// Ignores:
+    ///
+    /// - [`Literal`](Expr::Literal)
+    /// - [`Columns`](Expr::Column)
+    /// - [`ScalarVariable`](Expr::ScalarVariable)
+    /// - [`Alias`](Expr::Alias)
+    /// - [`Sort`](Expr::Sort)
+    /// - [`Wildcard`](Expr::Wildcard)
+    /// - [`AggregateFunction`](Expr::AggregateFunction)
+    /// - [`AggregateUDF`](Expr::AggregateUDF)
+    Normal,
+
+    /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction) and [`AggregateUDF`](Expr::AggregateUDF).
+    NormalAndAggregates,
+}
+
+impl ExprMask {
+    fn ignores(&self, expr: &Expr) -> bool {
+        let is_normal_minus_aggregates = matches!(
+            expr,
+            Expr::Literal(..)
+                | Expr::Column(..)
+                | Expr::ScalarVariable(..)
+                | Expr::Alias(..)
+                | Expr::Sort { .. }
+                | Expr::Wildcard
+        );
+
+        let is_aggr = matches!(
+            expr,
+            Expr::AggregateFunction(..) | Expr::AggregateUDF { .. }
+        );
+
+        match self {
+            Self::Normal => is_normal_minus_aggregates || is_aggr,
+            Self::NormalAndAggregates => is_normal_minus_aggregates,
+        }
+    }
+}
+
 /// Go through an expression tree and generate identifier.
 ///
 /// An identifier contains information of the expression itself and its sub-expression.
@@ -379,6 +532,8 @@ struct ExprIdentifierVisitor<'a> {
     node_count: usize,
     /// increased in post_visit, start from 1.
     series_number: usize,
+    /// which expression should be skipped?
+    expr_mask: ExprMask,
 }
 
 /// Record item that used when traversing a expression tree.
@@ -432,15 +587,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
 
         let (idx, sub_expr_desc) = self.pop_enter_mark();
         // skip exprs should not be recognize.
-        if matches!(
-            expr,
-            Expr::Literal(..)
-                | Expr::Column(..)
-                | Expr::ScalarVariable(..)
-                | Expr::Alias(..)
-                | Expr::Sort { .. }
-                | Expr::Wildcard
-        ) {
+        if self.expr_mask.ignores(expr) {
             self.id_array[idx].0 = self.series_number;
             let desc = Self::desc_expr(expr);
             self.visit_stack.push(VisitRecord::ExprItem(desc));
@@ -468,6 +615,7 @@ fn expr_to_identifier(
     expr_set: &mut ExprSet,
     id_array: &mut Vec<(usize, Identifier)>,
     input_schema: DFSchemaRef,
+    expr_mask: ExprMask,
 ) -> Result<()> {
     expr.visit(&mut ExprIdentifierVisitor {
         expr_set,
@@ -476,6 +624,7 @@ fn expr_to_identifier(
         visit_stack: vec![],
         node_count: 0,
         series_number: 0,
+        expr_mask,
     })?;
 
     Ok(())
@@ -589,8 +738,11 @@ mod test {
     use datafusion_common::DFSchema;
     use datafusion_expr::logical_plan::{table_scan, JoinType};
     use datafusion_expr::{
-        avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
-        Operator,
+        avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
+    };
+    use datafusion_expr::{
+        AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
+        StateTypeFunction, Volatility,
     };
 
     use crate::optimizer::OptimizerContext;
@@ -610,15 +762,7 @@ mod test {
 
     #[test]
     fn id_array_visitor() -> Result<()> {
-        let expr = binary_expr(
-            binary_expr(
-                sum(binary_expr(col("a"), Operator::Plus, lit(1))),
-                Operator::Minus,
-                avg(col("c")),
-            ),
-            Operator::Multiply,
-            lit(2),
-        );
+        let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2);
 
         let schema = Arc::new(DFSchema::new_with_metadata(
             vec![
@@ -628,12 +772,40 @@ mod test {
             Default::default(),
         )?);
 
+        // skip aggregates
+        let mut id_array = vec![];
+        expr_to_identifier(
+            &expr,
+            &mut HashMap::new(),
+            &mut id_array,
+            Arc::clone(&schema),
+            ExprMask::Normal,
+        )?;
+
+        let expected = vec![
+            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), 
+            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), 
+            (4, ""), 
+            (3, "a + Int32(1)Int32(1)a"), 
+            (1, ""), 
+            (2, ""), 
+            (6, ""), 
+            (5, ""), 
+            (8, "")
+        ]
+        .into_iter()
+        .map(|(number, id)| (number, id.into()))
+        .collect::<Vec<_>>();
+        assert_eq!(expected, id_array);
+
+        // include aggregates
         let mut id_array = vec![];
         expr_to_identifier(
             &expr,
             &mut HashMap::new(),
             &mut id_array,
             Arc::clone(&schema),
+            ExprMask::NormalAndAggregates,
         )?;
 
         let expected = vec![
@@ -671,20 +843,8 @@ mod test {
             .aggregate(
                 iter::empty::<Expr>(),
                 vec![
-                    sum(binary_expr(
-                        col("a"),
-                        Operator::Multiply,
-                        binary_expr(lit(1), Operator::Minus, col("b")),
-                    )),
-                    sum(binary_expr(
-                        binary_expr(
-                            col("a"),
-                            Operator::Multiply,
-                            binary_expr(lit(1), Operator::Minus, col("b")),
-                        ),
-                        Operator::Multiply,
-                        binary_expr(lit(1), Operator::Plus, col("c")),
-                    )),
+                    sum(col("a") * (lit(1) - col("b"))),
+                    sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
                 ],
             )?
             .build()?;
@@ -702,22 +862,127 @@ mod test {
     fn aggregate() -> Result<()> {
         let table_scan = test_table_scan()?;
 
-        let plan = LogicalPlanBuilder::from(table_scan)
+        let return_type: ReturnTypeFunction = Arc::new(|inputs| {
+            assert_eq!(inputs, &[DataType::UInt32]);
+            Ok(Arc::new(DataType::UInt32))
+        });
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| unimplemented!());
+        let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
+        let udf_agg = |inner: Expr| Expr::AggregateUDF {
+            fun: Arc::new(AggregateUDF::new(
+                "my_agg",
+                &Signature::exact(vec![DataType::UInt32], Volatility::Stable),
+                &return_type,
+                &accumulator,
+                &state_type,
+            )),
+            args: vec![inner],
+            filter: None,
+        };
+
+        // test: common aggregates
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
+            .aggregate(
+                iter::empty::<Expr>(),
+                vec![
+                    // common: avg(col("a"))
+                    avg(col("a")).alias("col1"),
+                    avg(col("a")).alias("col2"),
+                    // no common
+                    avg(col("b")).alias("col3"),
+                    avg(col("c")),
+                    // common: udf_agg(col("a"))
+                    udf_agg(col("a")).alias("col4"),
+                    udf_agg(col("a")).alias("col5"),
+                    // no common
+                    udf_agg(col("b")).alias("col6"),
+                    udf_agg(col("c")),
+                ],
+            )?
+            .build()?;
+
+        let expected = "Projection: AVG(test.a)test.a AS AVG(test.a) AS col1, AVG(test.a)test.a AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c), my_agg(test.a)test.a AS my_agg(test.a) AS col4, my_agg(test.a)test.a AS my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\
+        \n  Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
+        \n    TableScan: test";
+
+        assert_optimized_plan_eq(expected, &plan);
+
+        // test: transformation after aggregate
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
+            .aggregate(
+                iter::empty::<Expr>(),
+                vec![
+                    lit(1) + avg(col("a")),
+                    lit(1) - avg(col("a")),
+                    lit(1) + udf_agg(col("a")),
+                    lit(1) - udf_agg(col("a")),
+                ],
+            )?
+            .build()?;
+
+        let expected = "Projection: Int32(1) + AVG(test.a)test.a AS AVG(test.a), Int32(1) - AVG(test.a)test.a AS AVG(test.a), Int32(1) + my_agg(test.a)test.a AS my_agg(test.a), Int32(1) - my_agg(test.a)test.a AS my_agg(test.a)\
+        \n  Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a]]\
+        \n    TableScan: test";
+
+        assert_optimized_plan_eq(expected, &plan);
+
+        // test: transformation before aggregate
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
             .aggregate(
                 iter::empty::<Expr>(),
                 vec![
-                    binary_expr(lit(1), Operator::Plus, avg(col("a"))),
-                    binary_expr(lit(1), Operator::Minus, avg(col("a"))),
+                    avg(lit(1u32) + col("a")).alias("col1"),
+                    udf_agg(lit(1u32) + col("a")).alias("col2"),
                 ],
             )?
             .build()?;
 
-        let expected = "Aggregate: groupBy=[[]], aggr=[[Int32(1) + AVG(test.a)test.a AS AVG(test.a), Int32(1) - AVG(test.a)test.a AS AVG(test.a)]]\
-        \n  Projection: AVG(test.a) AS AVG(test.a)test.a, test.a, test.b, test.c\
+        let expected = "Aggregate: groupBy=[[]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col1, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col2]]\
+        \n  Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
         \n    TableScan: test";
 
         assert_optimized_plan_eq(expected, &plan);
 
+        // test: common between agg and group
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
+            .aggregate(
+                vec![lit(1u32) + col("a")],
+                vec![
+                    avg(lit(1u32) + col("a")).alias("col1"),
+                    udf_agg(lit(1u32) + col("a")).alias("col2"),
+                ],
+            )?
+            .build()?;
+
+        let expected = "Aggregate: groupBy=[[UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col1, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col2]]\
+        \n  Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
+        \n    TableScan: test";
+
+        assert_optimized_plan_eq(expected, &plan);
+
+        // test: all mixed
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(
+                vec![lit(1u32) + col("a")],
+                vec![
+                    (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
+                    (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
+                    avg(lit(1u32) + col("a")),
+                    (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
+                    (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
+                    udf_agg(lit(1u32) + col("a")),
+                ],
+            )?
+            .build()?;
+
+        let expected = "Projection: UInt32(1) + test.a, UInt32(1) + AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS AVG(UInt32(1) + test.a) AS col1, UInt32(1) - AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS AVG(UInt32(1) + test.a) AS col2, AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + [...]
+        \n  Aggregate: groupBy=[[UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]]\
+        \n    Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
+        \n      TableScan: test";
+
+        assert_optimized_plan_eq(expected, &plan);
+
         Ok(())
     }
 
@@ -727,8 +992,8 @@ mod test {
 
         let plan = LogicalPlanBuilder::from(table_scan)
             .project(vec![
-                binary_expr(lit(1), Operator::Plus, col("a")).alias("first"),
-                binary_expr(lit(1), Operator::Plus, col("a")).alias("second"),
+                (lit(1) + col("a")).alias("first"),
+                (lit(1) + col("a")).alias("second"),
             ])?
             .build()?;
 
@@ -746,10 +1011,7 @@ mod test {
         let table_scan = test_table_scan()?;
 
         let plan = LogicalPlanBuilder::from(table_scan)
-            .project(vec![
-                binary_expr(lit(1), Operator::Plus, col("a")),
-                binary_expr(col("a"), Operator::Plus, lit(1)),
-            ])?
+            .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
             .build()?;
 
         let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
@@ -765,8 +1027,8 @@ mod test {
         let table_scan = test_table_scan()?;
 
         let plan = LogicalPlanBuilder::from(table_scan)
-            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
-            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
+            .project(vec![lit(1) + col("a")])?
+            .project(vec![lit(1) + col("a")])?
             .build()?;
 
         let expected = "Projection: Int32(1) + test.a\