You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/10/08 00:18:15 UTC

[arrow-datafusion] branch master updated: Refactor `Expr::Case` to use a struct (#3757)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 488b2cec3 Refactor `Expr::Case` to use a struct (#3757)
488b2cec3 is described below

commit 488b2cec3c700821dfdfece2d85c4cd7956e718d
Author: Andy Grove <an...@gmail.com>
AuthorDate: Fri Oct 7 18:18:09 2022 -0600

    Refactor `Expr::Case` to use a struct (#3757)
---
 datafusion/core/src/dataframe.rs                 |  1 -
 datafusion/core/src/physical_plan/planner.rs     | 12 ++--
 datafusion/expr/src/conditional_expressions.rs   | 12 ++--
 datafusion/expr/src/expr.rs                      | 60 +++++++++++---------
 datafusion/expr/src/expr_rewriter.rs             | 21 +++----
 datafusion/expr/src/expr_schema.rs               | 13 ++---
 datafusion/expr/src/expr_visitor.rs              | 12 ++--
 datafusion/optimizer/src/simplify_expressions.rs | 71 +++++++++++-------------
 datafusion/optimizer/src/type_coercion.rs        | 24 +++-----
 datafusion/physical-expr/src/planner.rs          | 36 ++++++------
 datafusion/proto/src/from_proto.rs               | 10 ++--
 datafusion/proto/src/lib.rs                      | 22 ++++----
 datafusion/proto/src/to_proto.rs                 | 12 ++--
 datafusion/sql/src/planner.rs                    |  8 +--
 datafusion/sql/src/utils.rs                      | 16 ++----
 15 files changed, 150 insertions(+), 180 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 06768b563..a5caad176 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -827,7 +827,6 @@ impl TableProvider for DataFrame {
 
 #[cfg(test)]
 mod tests {
-    use arrow::array::Int32Array;
     use std::vec;
 
     use super::*;
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 86d005776..8bb1d95a4 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -111,19 +111,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
             let right = create_physical_name(right, false)?;
             Ok(format!("{} {} {}", left, op, right))
         }
-        Expr::Case {
-            expr,
-            when_then_expr,
-            else_expr,
-        } => {
+        Expr::Case(case) => {
             let mut name = "CASE ".to_string();
-            if let Some(e) = expr {
+            if let Some(e) = &case.expr {
                 let _ = write!(name, "{:?} ", e);
             }
-            for (w, t) in when_then_expr {
+            for (w, t) in &case.when_then_expr {
                 let _ = write!(name, "WHEN {:?} THEN {:?} ", w, t);
             }
-            if let Some(e) = else_expr {
+            if let Some(e) = &case.else_expr {
                 let _ = write!(name, "ELSE {:?} ", e);
             }
             name += "END";
diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs
index 0c5104a4b..28ac6e8cd 100644
--- a/datafusion/expr/src/conditional_expressions.rs
+++ b/datafusion/expr/src/conditional_expressions.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::expr::Case;
 ///! Conditional expressions
 use crate::{expr_schema::ExprSchemable, Expr};
 use arrow::datatypes::DataType;
@@ -108,16 +109,15 @@ impl CaseBuilder {
             }
         }
 
-        Ok(Expr::Case {
-            expr: self.expr.clone(),
-            when_then_expr: self
-                .when_expr
+        Ok(Expr::Case(Case::new(
+            self.expr.clone(),
+            self.when_expr
                 .iter()
                 .zip(self.then_expr.iter())
                 .map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone())))
                 .collect(),
-            else_expr: self.else_expr.clone(),
-        })
+            self.else_expr.clone(),
+        )))
     }
 }
 
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 008a2c454..c131682a8 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -176,14 +176,7 @@ pub enum Expr {
     ///     [WHEN ...]
     ///     [ELSE result]
     /// END
-    Case {
-        /// Optional base expression that can be compared to literal values in the "when" expressions
-        expr: Option<Box<Expr>>,
-        /// One or more when/then expressions
-        when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
-        /// Optional "else" expression
-        else_expr: Option<Box<Expr>>,
-    },
+    Case(Case),
     /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast.
     /// This expression is guaranteed to have a fixed type.
     Cast {
@@ -292,6 +285,32 @@ pub enum Expr {
     GroupingSet(GroupingSet),
 }
 
+/// CASE expression: `CASE [expr] WHEN <when> THEN <then> [WHEN ...] [ELSE <else>] END`
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct Case {
+    /// Optional base expression that can be compared to literal values in the "when" expressions
+    pub expr: Option<Box<Expr>>,
+    /// One or more `(when, then)` expression pairs
+    pub when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
+    /// Optional "else" expression
+    pub else_expr: Option<Box<Expr>>,
+}
+
+impl Case {
+    /// Create a new `Case` from its optional base expression, when/then pairs and optional else
+    pub fn new(
+        expr: Option<Box<Expr>>,
+        when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
+        else_expr: Option<Box<Expr>>,
+    ) -> Self {
+        Self {
+            expr,
+            when_then_expr,
+            else_expr,
+        }
+    }
+}
+
 /// Grouping sets
 /// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS
 /// for Postgres definition.
@@ -601,20 +620,15 @@ impl fmt::Debug for Expr {
             Expr::Column(c) => write!(f, "{}", c),
             Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")),
             Expr::Literal(v) => write!(f, "{:?}", v),
-            Expr::Case {
-                expr,
-                when_then_expr,
-                else_expr,
-                ..
-            } => {
+            Expr::Case(case) => {
                 write!(f, "CASE ")?;
-                if let Some(e) = expr {
+                if let Some(e) = &case.expr {
                     write!(f, "{:?} ", e)?;
                 }
-                for (w, t) in when_then_expr {
+                for (w, t) in &case.when_then_expr {
                     write!(f, "WHEN {:?} THEN {:?} ", w, t)?;
                 }
-                if let Some(e) = else_expr {
+                if let Some(e) = &case.else_expr {
                     write!(f, "ELSE {:?} ", e)?;
                 }
                 write!(f, "END")
@@ -957,22 +971,18 @@ fn create_name(e: &Expr) -> Result<String> {
             );
             Ok(s)
         }
-        Expr::Case {
-            expr,
-            when_then_expr,
-            else_expr,
-        } => {
+        Expr::Case(case) => {
             let mut name = "CASE ".to_string();
-            if let Some(e) = expr {
+            if let Some(e) = &case.expr {
                 let e = create_name(e)?;
                 let _ = write!(name, "{} ", e);
             }
-            for (w, t) in when_then_expr {
+            for (w, t) in &case.when_then_expr {
                 let when = create_name(w)?;
                 let then = create_name(t)?;
                 let _ = write!(name, "WHEN {} THEN {} ", when, then);
             }
-            if let Some(e) = else_expr {
+            if let Some(e) = &case.else_expr {
                 let e = create_name(e)?;
                 let _ = write!(name, "ELSE {} ", e);
             }
diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs
index 6bdb54522..427fcf170 100644
--- a/datafusion/expr/src/expr_rewriter.rs
+++ b/datafusion/expr/src/expr_rewriter.rs
@@ -17,7 +17,7 @@
 
 //! Expression rewriter
 
-use crate::expr::GroupingSet;
+use crate::expr::{Case, GroupingSet};
 use crate::logical_plan::{Aggregate, Projection};
 use crate::utils::{from_plan, grouping_set_to_exprlist};
 use crate::{Expr, ExprSchemable, LogicalPlan};
@@ -184,13 +184,10 @@ impl ExprRewritable for Expr {
                 high: rewrite_boxed(high, rewriter)?,
                 negated,
             },
-            Expr::Case {
-                expr,
-                when_then_expr,
-                else_expr,
-            } => {
-                let expr = rewrite_option_box(expr, rewriter)?;
-                let when_then_expr = when_then_expr
+            Expr::Case(case) => {
+                let expr = rewrite_option_box(case.expr, rewriter)?;
+                let when_then_expr = case
+                    .when_then_expr
                     .into_iter()
                     .map(|(when, then)| {
                         Ok((
@@ -200,13 +197,9 @@ impl ExprRewritable for Expr {
                     })
                     .collect::<Result<Vec<_>>>()?;
 
-                let else_expr = rewrite_option_box(else_expr, rewriter)?;
+                let else_expr = rewrite_option_box(case.else_expr, rewriter)?;
 
-                Expr::Case {
-                    expr,
-                    when_then_expr,
-                    else_expr,
-                }
+                Expr::Case(Case::new(expr, when_then_expr, else_expr))
             }
             Expr::Cast { expr, data_type } => Expr::Cast {
                 expr: rewrite_boxed(expr, rewriter)?,
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 88d767366..5442a2421 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -59,7 +59,7 @@ impl ExprSchemable for Expr {
             Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
             Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
             Expr::Literal(l) => Ok(l.get_datatype()),
-            Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
+            Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
             Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
                 Ok(data_type.clone())
             }
@@ -164,19 +164,16 @@ impl ExprSchemable for Expr {
             | Expr::InList { expr, .. } => expr.nullable(input_schema),
             Expr::Column(c) => input_schema.nullable(c),
             Expr::Literal(value) => Ok(value.is_null()),
-            Expr::Case {
-                when_then_expr,
-                else_expr,
-                ..
-            } => {
+            Expr::Case(case) => {
                 // this expression is nullable if any of the input expressions are nullable
-                let then_nullable = when_then_expr
+                let then_nullable = case
+                    .when_then_expr
                     .iter()
                     .map(|(_, t)| t.nullable(input_schema))
                     .collect::<Result<Vec<_>>>()?;
                 if then_nullable.contains(&true) {
                     Ok(true)
-                } else if let Some(e) = else_expr {
+                } else if let Some(e) = &case.else_expr {
                     e.nullable(input_schema)
                 } else {
                     // CASE produces NULL if there is no `else` expr
diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs
index 3885456cc..f362a759e 100644
--- a/datafusion/expr/src/expr_visitor.rs
+++ b/datafusion/expr/src/expr_visitor.rs
@@ -153,24 +153,20 @@ impl ExprVisitable for Expr {
                 let visitor = low.accept(visitor)?;
                 high.accept(visitor)
             }
-            Expr::Case {
-                expr,
-                when_then_expr,
-                else_expr,
-            } => {
-                let visitor = if let Some(expr) = expr.as_ref() {
+            Expr::Case(case) => {
+                let visitor = if let Some(expr) = case.expr.as_ref() {
                     expr.accept(visitor)
                 } else {
                     Ok(visitor)
                 }?;
-                let visitor = when_then_expr.iter().try_fold(
+                let visitor = case.when_then_expr.iter().try_fold(
                     visitor,
                     |visitor, (when, then)| {
                         let visitor = when.accept(visitor)?;
                         then.accept(visitor)
                     },
                 )?;
-                if let Some(else_expr) = else_expr.as_ref() {
+                if let Some(else_expr) = case.else_expr.as_ref() {
                     else_expr.accept(visitor)
                 } else {
                     Ok(visitor)
diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs
index bb17a9925..c96f6eea7 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -490,7 +490,7 @@ impl<'a> ConstEvaluator<'a> {
             | Expr::Like { .. }
             | Expr::ILike { .. }
             | Expr::SimilarTo { .. }
-            | Expr::Case { .. }
+            | Expr::Case(_)
             | Expr::Cast { .. }
             | Expr::TryCast { .. }
             | Expr::InList { .. }
@@ -848,20 +848,17 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
             //
             // Note: the rationale for this rewrite is that the expr can then be further
             // simplified using the existing rules for AND/OR
-            Case {
-                expr: None,
-                when_then_expr,
-                else_expr,
-            } if !when_then_expr.is_empty()
-                && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
-                && info.is_boolean_type(&when_then_expr[0].1)? =>
+            Case(case)
+                if case.expr.is_none() && !case.when_then_expr.is_empty() // no base expr: WHENs are predicates
+                && case.when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
+                && info.is_boolean_type(&case.when_then_expr[0].1)? =>
             {
                 // The disjunction of all the when predicates encountered so far
                 let mut filter_expr = lit(false);
                 // The disjunction of all the cases
                 let mut out_expr = lit(false);
 
-                for (when, then) in when_then_expr {
+                for (when, then) in case.when_then_expr {
                     let case_expr = when
                         .as_ref()
                         .clone()
@@ -872,7 +869,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
                     filter_expr = filter_expr.or(*when);
                 }
 
-                if let Some(else_expr) = else_expr {
+                if let Some(else_expr) = case.else_expr {
                     let case_expr = filter_expr.not().and(*else_expr);
                     out_expr = out_expr.or(case_expr);
                 }
@@ -974,6 +971,7 @@ mod tests {
     use arrow::array::{ArrayRef, Int32Array};
     use chrono::{DateTime, TimeZone, Utc};
     use datafusion_common::{DFField, ToDFSchema};
+    use datafusion_expr::expr::Case;
     use datafusion_expr::logical_plan::table_scan;
     use datafusion_expr::{
         and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano,
@@ -1700,14 +1698,14 @@ mod tests {
         // -->
         // false
         assert_eq!(
-            simplify(Expr::Case {
-                expr: None,
-                when_then_expr: vec![(
+            simplify(Expr::Case(Case::new(
+                None,
+                vec![(
                     Box::new(col("c2").not_eq(lit(false))),
                     Box::new(lit("ok").eq(lit("not_ok"))),
                 )],
-                else_expr: Some(Box::new(col("c2").eq(lit(true)))),
-            }),
+                Some(Box::new(col("c2").eq(lit(true)))),
+            ))),
             col("c2").not().and(col("c2")) // #1716
         );
 
@@ -1720,14 +1718,14 @@ mod tests {
         // Need to call simplify 2x due to
         // https://github.com/apache/arrow-datafusion/issues/1160
         assert_eq!(
-            simplify(simplify(Expr::Case {
-                expr: None,
-                when_then_expr: vec![(
+            simplify(simplify(Expr::Case(Case::new(
+                None,
+                vec![(
                     Box::new(col("c2").not_eq(lit(false))),
                     Box::new(lit("ok").eq(lit("ok"))),
                 )],
-                else_expr: Some(Box::new(col("c2").eq(lit(true)))),
-            })),
+                Some(Box::new(col("c2").eq(lit(true)))),
+            )))),
             col("c2").or(col("c2").not().and(col("c2"))) // #1716
         );
 
@@ -1738,14 +1736,11 @@ mod tests {
         // Need to call simplify 2x due to
         // https://github.com/apache/arrow-datafusion/issues/1160
         assert_eq!(
-            simplify(simplify(Expr::Case {
-                expr: None,
-                when_then_expr: vec![(
-                    Box::new(col("c2").is_null()),
-                    Box::new(lit(true)),
-                )],
-                else_expr: Some(Box::new(col("c2"))),
-            })),
+            simplify(simplify(Expr::Case(Case::new(
+                None,
+                vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
+                Some(Box::new(col("c2"))),
+            )))),
             col("c2")
                 .is_null()
                 .or(col("c2").is_not_null().and(col("c2")))
@@ -1759,14 +1754,14 @@ mod tests {
         // Need to call simplify 2x due to
         // https://github.com/apache/arrow-datafusion/issues/1160
         assert_eq!(
-            simplify(simplify(Expr::Case {
-                expr: None,
-                when_then_expr: vec![
+            simplify(simplify(Expr::Case(Case::new(
+                None,
+                vec![
                     (Box::new(col("c1")), Box::new(lit(true)),),
                     (Box::new(col("c2")), Box::new(lit(false)),),
                 ],
-                else_expr: Some(Box::new(lit(true))),
-            })),
+                Some(Box::new(lit(true))),
+            )))),
             col("c1").or(col("c1").not().and(col("c2").not()))
         );
 
@@ -1778,14 +1773,14 @@ mod tests {
         // Need to call simplify 2x due to
         // https://github.com/apache/arrow-datafusion/issues/1160
         assert_eq!(
-            simplify(simplify(Expr::Case {
-                expr: None,
-                when_then_expr: vec![
+            simplify(simplify(Expr::Case(Case::new(
+                None,
+                vec![
                     (Box::new(col("c1")), Box::new(lit(true)),),
                     (Box::new(col("c2")), Box::new(lit(false)),),
                 ],
-                else_expr: Some(Box::new(lit(true))),
-            })),
+                Some(Box::new(lit(true))),
+            )))),
             col("c1").or(col("c1").not().and(col("c2").not()))
         );
     }
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index f0470da87..fcfe6eaaa 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -20,6 +20,7 @@
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::DataType;
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
+use datafusion_expr::expr::Case;
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
 use datafusion_expr::logical_plan::Subquery;
 use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion};
@@ -357,18 +358,15 @@ impl ExprRewriter for TypeCoercionRewriter {
                     }
                 }
             }
-            Expr::Case {
-                expr,
-                when_then_expr,
-                else_expr,
-            } => {
+            Expr::Case(case) => {
                 // all the result of then and else should be convert to a common data type,
                 // if they can be coercible to a common data type, return error.
-                let then_types = when_then_expr
+                let then_types = case
+                    .when_then_expr
                     .iter()
                     .map(|when_then| when_then.1.get_type(&self.schema))
                     .collect::<Result<Vec<_>>>()?;
-                let else_type = match &else_expr {
+                let else_type = match &case.else_expr {
                     None => Ok(None),
                     Some(expr) => expr.get_type(&self.schema).map(Some),
                 }?;
@@ -380,24 +378,20 @@ impl ExprRewriter for TypeCoercionRewriter {
                         then_types, else_type
                     ))),
                     Some(data_type) => {
-                        let left = when_then_expr
+                        let left = case.when_then_expr
                             .into_iter()
                             .map(|(when, then)| {
                                 let then = then.cast_to(&data_type, &self.schema)?;
                                 Ok((when, Box::new(then)))
                             })
                             .collect::<Result<Vec<_>>>()?;
-                        let right = match else_expr {
+                        let right = match case.else_expr {
                             None => None,
                             Some(expr) => {
-                                Some(Box::new(expr.cast_to(&data_type, &self.schema)?))
+                                Some(Box::new(expr.cast_to(&data_type, &self.schema)?))
                             }
                         };
-                        Ok(Expr::Case {
-                            expr,
-                            when_then_expr: left,
-                            else_expr: right,
-                        })
+                        Ok(Expr::Case(Case::new(case.expr, left, right)))
                     }
                 }
             }
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 0964d6480..993891884 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -221,13 +221,8 @@ pub fn create_physical_expr(
                 binary_expr(expr.as_ref().clone(), op, pattern.as_ref().clone());
             create_physical_expr(&bin_expr, input_dfschema, input_schema, execution_props)
         }
-        Expr::Case {
-            expr,
-            when_then_expr,
-            else_expr,
-            ..
-        } => {
-            let expr: Option<Arc<dyn PhysicalExpr>> = if let Some(e) = expr {
+        Expr::Case(case) => {
+            let expr: Option<Arc<dyn PhysicalExpr>> = if let Some(e) = &case.expr {
                 Some(create_physical_expr(
                     e.as_ref(),
                     input_dfschema,
@@ -237,7 +232,8 @@ pub fn create_physical_expr(
             } else {
                 None
             };
-            let when_expr = when_then_expr
+            let when_expr = case
+                .when_then_expr
                 .iter()
                 .map(|(w, _)| {
                     create_physical_expr(
@@ -248,7 +244,8 @@ pub fn create_physical_expr(
                     )
                 })
                 .collect::<Result<Vec<_>>>()?;
-            let then_expr = when_then_expr
+            let then_expr = case
+                .when_then_expr
                 .iter()
                 .map(|(_, t)| {
                     create_physical_expr(
@@ -265,16 +262,17 @@ pub fn create_physical_expr(
                     .zip(then_expr.iter())
                     .map(|(w, t)| (w.clone(), t.clone()))
                     .collect();
-            let else_expr: Option<Arc<dyn PhysicalExpr>> = if let Some(e) = else_expr {
-                Some(create_physical_expr(
-                    e.as_ref(),
-                    input_dfschema,
-                    input_schema,
-                    execution_props,
-                )?)
-            } else {
-                None
-            };
+            let else_expr: Option<Arc<dyn PhysicalExpr>> =
+                if let Some(e) = &case.else_expr {
+                    Some(create_physical_expr(
+                        e.as_ref(),
+                        input_dfschema,
+                        input_schema,
+                        execution_props,
+                    )?)
+                } else {
+                    None
+                };
             Ok(expressions::case(expr, when_then_expr, else_expr)?)
         }
         Expr::Cast { expr, data_type } => expressions::cast(
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index 3eeb30edf..208c24036 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -31,8 +31,8 @@ use datafusion::logical_plan::FunctionRegistry;
 use datafusion_common::{
     Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue,
 };
-use datafusion_expr::expr::GroupingSet;
 use datafusion_expr::expr::GroupingSet::GroupingSets;
+use datafusion_expr::expr::{Case, GroupingSet};
 use datafusion_expr::{
     abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil,
     character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_bin,
@@ -1023,11 +1023,11 @@ pub fn parse_expr(
                     Ok((Box::new(when_expr), Box::new(then_expr)))
                 })
                 .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>, Error>>()?;
-            Ok(Expr::Case {
-                expr: parse_optional_expr(&case.expr, registry)?.map(Box::new),
+            Ok(Expr::Case(Case::new(
+                parse_optional_expr(&case.expr, registry)?.map(Box::new),
                 when_then_expr,
-                else_expr: parse_optional_expr(&case.else_expr, registry)?.map(Box::new),
-            })
+                parse_optional_expr(&case.else_expr, registry)?.map(Box::new),
+            )))
         }
         ExprType::Cast(cast) => {
             let expr = Box::new(parse_required_expr(&cast.expr, registry, "expr")?);
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index e3b6c848a..7a495077f 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -63,7 +63,7 @@ mod roundtrip_tests {
     use datafusion::physical_plan::functions::make_scalar_function;
     use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext};
     use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
-    use datafusion_expr::expr::GroupingSet;
+    use datafusion_expr::expr::{Case, GroupingSet};
     use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
     use datafusion_expr::{
         col, lit, Accumulator, AggregateFunction, AggregateState,
@@ -970,11 +970,11 @@ mod roundtrip_tests {
 
     #[test]
     fn roundtrip_case() {
-        let test_expr = Expr::Case {
-            expr: Some(Box::new(lit(1.0_f32))),
-            when_then_expr: vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
-            else_expr: Some(Box::new(lit(4.0_f32))),
-        };
+        let test_expr = Expr::Case(Case::new(
+            Some(Box::new(lit(1.0_f32))),
+            vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
+            Some(Box::new(lit(4.0_f32))),
+        ));
 
         let ctx = SessionContext::new();
         roundtrip_expr_test(test_expr, ctx);
@@ -982,11 +982,11 @@ mod roundtrip_tests {
 
     #[test]
     fn roundtrip_case_with_null() {
-        let test_expr = Expr::Case {
-            expr: Some(Box::new(lit(1.0_f32))),
-            when_then_expr: vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
-            else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))),
-        };
+        let test_expr = Expr::Case(Case::new(
+            Some(Box::new(lit(1.0_f32))),
+            vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
+            Some(Box::new(Expr::Literal(ScalarValue::Null))),
+        ));
 
         let ctx = SessionContext::new();
         roundtrip_expr_test(test_expr, ctx);
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 47b779fff..7b70821ec 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -771,12 +771,8 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                     expr_type: Some(ExprType::Between(expr)),
                 }
             }
-            Expr::Case {
-                expr,
-                when_then_expr,
-                else_expr,
-            } => {
-                let when_then_expr = when_then_expr
+            Expr::Case(case) => {
+                let when_then_expr = case.when_then_expr
                     .iter()
                     .map(|(w, t)| {
                         Ok(protobuf::WhenThen {
@@ -786,12 +782,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                     })
                     .collect::<Result<Vec<protobuf::WhenThen>, Error>>()?;
                 let expr = Box::new(protobuf::CaseNode {
-                    expr: match expr {
+                    expr: match &case.expr {
                         Some(e) => Some(Box::new(e.as_ref().try_into()?)),
                         None => None,
                     },
                     when_then_expr,
-                    else_expr: match else_expr {
+                    else_expr: match &case.else_expr {
                         Some(e) => Some(Box::new(e.as_ref().try_into()?)),
                         None => None,
                     },
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 9a4e29228..58b65af59 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -51,7 +51,7 @@ use crate::utils::{make_decimal_type, normalize_ident, resolve_columns};
 use datafusion_common::{
     field_not_found, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
 };
-use datafusion_expr::expr::GroupingSet;
+use datafusion_expr::expr::{Case, GroupingSet};
 use datafusion_expr::logical_plan::builder::project_with_alias;
 use datafusion_expr::logical_plan::{Filter, Subquery};
 use datafusion_expr::Expr::Alias;
@@ -1872,15 +1872,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     None
                 };
 
-                Ok(Expr::Case {
+                Ok(Expr::Case(Case::new(
                     expr,
-                    when_then_expr: when_expr
+                    when_expr
                         .iter()
                         .zip(then_expr.iter())
                         .map(|(w, t)| (Box::new(w.to_owned()), Box::new(t.to_owned())))
                         .collect(),
                     else_expr,
-                })
+                )))
             }
 
             SQLExpr::Cast {
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index eb58509d0..952ef3110 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -21,7 +21,7 @@ use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE
 use sqlparser::ast::Ident;
 
 use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::expr::GroupingSet;
+use datafusion_expr::expr::{Case, GroupingSet};
 use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
 use datafusion_expr::{Expr, LogicalPlan};
 use std::collections::HashMap;
@@ -268,18 +268,14 @@ where
                 pattern: Box::new(clone_with_replacement(pattern, replacement_fn)?),
                 escape_char: *escape_char,
             }),
-            Expr::Case {
-                expr: case_expr_opt,
-                when_then_expr,
-                else_expr: else_expr_opt,
-            } => Ok(Expr::Case {
-                expr: match case_expr_opt {
+            Expr::Case(case) => Ok(Expr::Case(Case::new(
+                match &case.expr {
                     Some(case_expr) => {
                         Some(Box::new(clone_with_replacement(case_expr, replacement_fn)?))
                     }
                     None => None,
                 },
-                when_then_expr: when_then_expr
+                case.when_then_expr
                     .iter()
                     .map(|(a, b)| {
                         Ok((
@@ -288,13 +284,13 @@ where
                         ))
                     })
                     .collect::<Result<Vec<(_, _)>>>()?,
-                else_expr: match else_expr_opt {
+                match &case.else_expr {
                     Some(else_expr) => {
                         Some(Box::new(clone_with_replacement(else_expr, replacement_fn)?))
                     }
                     None => None,
                 },
-            }),
+            ))),
             Expr::ScalarFunction { fun, args } => Ok(Expr::ScalarFunction {
                 fun: fun.clone(),
                 args: args