You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/10/08 00:18:15 UTC
[arrow-datafusion] branch master updated: Refactor `Expr::Case` to use a struct (#3757)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 488b2cec3 Refactor `Expr::Case` to use a struct (#3757)
488b2cec3 is described below
commit 488b2cec3c700821dfdfece2d85c4cd7956e718d
Author: Andy Grove <an...@gmail.com>
AuthorDate: Fri Oct 7 18:18:09 2022 -0600
Refactor `Expr::Case` to use a struct (#3757)
---
datafusion/core/src/dataframe.rs | 1 -
datafusion/core/src/physical_plan/planner.rs | 12 ++--
datafusion/expr/src/conditional_expressions.rs | 12 ++--
datafusion/expr/src/expr.rs | 60 +++++++++++---------
datafusion/expr/src/expr_rewriter.rs | 21 +++----
datafusion/expr/src/expr_schema.rs | 13 ++---
datafusion/expr/src/expr_visitor.rs | 12 ++--
datafusion/optimizer/src/simplify_expressions.rs | 71 +++++++++++-------------
datafusion/optimizer/src/type_coercion.rs | 24 +++-----
datafusion/physical-expr/src/planner.rs | 36 ++++++------
datafusion/proto/src/from_proto.rs | 10 ++--
datafusion/proto/src/lib.rs | 22 ++++----
datafusion/proto/src/to_proto.rs | 12 ++--
datafusion/sql/src/planner.rs | 8 +--
datafusion/sql/src/utils.rs | 16 ++----
15 files changed, 150 insertions(+), 180 deletions(-)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 06768b563..a5caad176 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -827,7 +827,6 @@ impl TableProvider for DataFrame {
#[cfg(test)]
mod tests {
- use arrow::array::Int32Array;
use std::vec;
use super::*;
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 86d005776..8bb1d95a4 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -111,19 +111,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
let right = create_physical_name(right, false)?;
Ok(format!("{} {} {}", left, op, right))
}
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
+ Expr::Case(case) => {
let mut name = "CASE ".to_string();
- if let Some(e) = expr {
+ if let Some(e) = &case.expr {
let _ = write!(name, "{:?} ", e);
}
- for (w, t) in when_then_expr {
+ for (w, t) in &case.when_then_expr {
let _ = write!(name, "WHEN {:?} THEN {:?} ", w, t);
}
- if let Some(e) = else_expr {
+ if let Some(e) = &case.else_expr {
let _ = write!(name, "ELSE {:?} ", e);
}
name += "END";
diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs
index 0c5104a4b..28ac6e8cd 100644
--- a/datafusion/expr/src/conditional_expressions.rs
+++ b/datafusion/expr/src/conditional_expressions.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use crate::expr::Case;
///! Conditional expressions
use crate::{expr_schema::ExprSchemable, Expr};
use arrow::datatypes::DataType;
@@ -108,16 +109,15 @@ impl CaseBuilder {
}
}
- Ok(Expr::Case {
- expr: self.expr.clone(),
- when_then_expr: self
- .when_expr
+ Ok(Expr::Case(Case::new(
+ self.expr.clone(),
+ self.when_expr
.iter()
.zip(self.then_expr.iter())
.map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone())))
.collect(),
- else_expr: self.else_expr.clone(),
- })
+ self.else_expr.clone(),
+ )))
}
}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 008a2c454..c131682a8 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -176,14 +176,7 @@ pub enum Expr {
/// [WHEN ...]
/// [ELSE result]
/// END
- Case {
- /// Optional base expression that can be compared to literal values in the "when" expressions
- expr: Option<Box<Expr>>,
- /// One or more when/then expressions
- when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
- /// Optional "else" expression
- else_expr: Option<Box<Expr>>,
- },
+ Case(Case),
/// Casts the expression to a given type and will return a runtime error if the expression cannot be cast.
/// This expression is guaranteed to have a fixed type.
Cast {
@@ -292,6 +285,32 @@ pub enum Expr {
GroupingSet(GroupingSet),
}
+/// CASE expression
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct Case {
+ /// Optional base expression that can be compared to literal values in the "when" expressions
+ pub expr: Option<Box<Expr>>,
+ /// One or more when/then expressions
+ pub when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
+ /// Optional "else" expression
+ pub else_expr: Option<Box<Expr>>,
+}
+
+impl Case {
+ /// Create a new Case expression
+ pub fn new(
+ expr: Option<Box<Expr>>,
+ when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
+ else_expr: Option<Box<Expr>>,
+ ) -> Self {
+ Self {
+ expr,
+ when_then_expr,
+ else_expr,
+ }
+ }
+}
+
/// Grouping sets
/// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS
/// for Postgres definition.
@@ -601,20 +620,15 @@ impl fmt::Debug for Expr {
Expr::Column(c) => write!(f, "{}", c),
Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")),
Expr::Literal(v) => write!(f, "{:?}", v),
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- ..
- } => {
+ Expr::Case(case) => {
write!(f, "CASE ")?;
- if let Some(e) = expr {
+ if let Some(e) = &case.expr {
write!(f, "{:?} ", e)?;
}
- for (w, t) in when_then_expr {
+ for (w, t) in &case.when_then_expr {
write!(f, "WHEN {:?} THEN {:?} ", w, t)?;
}
- if let Some(e) = else_expr {
+ if let Some(e) = &case.else_expr {
write!(f, "ELSE {:?} ", e)?;
}
write!(f, "END")
@@ -957,22 +971,18 @@ fn create_name(e: &Expr) -> Result<String> {
);
Ok(s)
}
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
+ Expr::Case(case) => {
let mut name = "CASE ".to_string();
- if let Some(e) = expr {
+ if let Some(e) = &case.expr {
let e = create_name(e)?;
let _ = write!(name, "{} ", e);
}
- for (w, t) in when_then_expr {
+ for (w, t) in &case.when_then_expr {
let when = create_name(w)?;
let then = create_name(t)?;
let _ = write!(name, "WHEN {} THEN {} ", when, then);
}
- if let Some(e) = else_expr {
+ if let Some(e) = &case.else_expr {
let e = create_name(e)?;
let _ = write!(name, "ELSE {} ", e);
}
diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs
index 6bdb54522..427fcf170 100644
--- a/datafusion/expr/src/expr_rewriter.rs
+++ b/datafusion/expr/src/expr_rewriter.rs
@@ -17,7 +17,7 @@
//! Expression rewriter
-use crate::expr::GroupingSet;
+use crate::expr::{Case, GroupingSet};
use crate::logical_plan::{Aggregate, Projection};
use crate::utils::{from_plan, grouping_set_to_exprlist};
use crate::{Expr, ExprSchemable, LogicalPlan};
@@ -184,13 +184,10 @@ impl ExprRewritable for Expr {
high: rewrite_boxed(high, rewriter)?,
negated,
},
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
- let expr = rewrite_option_box(expr, rewriter)?;
- let when_then_expr = when_then_expr
+ Expr::Case(case) => {
+ let expr = rewrite_option_box(case.expr, rewriter)?;
+ let when_then_expr = case
+ .when_then_expr
.into_iter()
.map(|(when, then)| {
Ok((
@@ -200,13 +197,9 @@ impl ExprRewritable for Expr {
})
.collect::<Result<Vec<_>>>()?;
- let else_expr = rewrite_option_box(else_expr, rewriter)?;
+ let else_expr = rewrite_option_box(case.else_expr, rewriter)?;
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- }
+ Expr::Case(Case::new(expr, when_then_expr, else_expr))
}
Expr::Cast { expr, data_type } => Expr::Cast {
expr: rewrite_boxed(expr, rewriter)?,
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 88d767366..5442a2421 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -59,7 +59,7 @@ impl ExprSchemable for Expr {
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
Expr::Literal(l) => Ok(l.get_datatype()),
- Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
+ Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
Ok(data_type.clone())
}
@@ -164,19 +164,16 @@ impl ExprSchemable for Expr {
| Expr::InList { expr, .. } => expr.nullable(input_schema),
Expr::Column(c) => input_schema.nullable(c),
Expr::Literal(value) => Ok(value.is_null()),
- Expr::Case {
- when_then_expr,
- else_expr,
- ..
- } => {
+ Expr::Case(case) => {
// this expression is nullable if any of the input expressions are nullable
- let then_nullable = when_then_expr
+ let then_nullable = case
+ .when_then_expr
.iter()
.map(|(_, t)| t.nullable(input_schema))
.collect::<Result<Vec<_>>>()?;
if then_nullable.contains(&true) {
Ok(true)
- } else if let Some(e) = else_expr {
+ } else if let Some(e) = &case.else_expr {
e.nullable(input_schema)
} else {
// CASE produces NULL if there is no `else` expr
diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs
index 3885456cc..f362a759e 100644
--- a/datafusion/expr/src/expr_visitor.rs
+++ b/datafusion/expr/src/expr_visitor.rs
@@ -153,24 +153,20 @@ impl ExprVisitable for Expr {
let visitor = low.accept(visitor)?;
high.accept(visitor)
}
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
- let visitor = if let Some(expr) = expr.as_ref() {
+ Expr::Case(case) => {
+ let visitor = if let Some(expr) = case.expr.as_ref() {
expr.accept(visitor)
} else {
Ok(visitor)
}?;
- let visitor = when_then_expr.iter().try_fold(
+ let visitor = case.when_then_expr.iter().try_fold(
visitor,
|visitor, (when, then)| {
let visitor = when.accept(visitor)?;
then.accept(visitor)
},
)?;
- if let Some(else_expr) = else_expr.as_ref() {
+ if let Some(else_expr) = case.else_expr.as_ref() {
else_expr.accept(visitor)
} else {
Ok(visitor)
diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs
index bb17a9925..c96f6eea7 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -490,7 +490,7 @@ impl<'a> ConstEvaluator<'a> {
| Expr::Like { .. }
| Expr::ILike { .. }
| Expr::SimilarTo { .. }
- | Expr::Case { .. }
+ | Expr::Case(_)
| Expr::Cast { .. }
| Expr::TryCast { .. }
| Expr::InList { .. }
@@ -848,20 +848,17 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
//
// Note: the rationale for this rewrite is that the expr can then be further
// simplified using the existing rules for AND/OR
- Case {
- expr: None,
- when_then_expr,
- else_expr,
- } if !when_then_expr.is_empty()
- && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
- && info.is_boolean_type(&when_then_expr[0].1)? =>
+ Case(case)
+ if case.expr.is_none() && !case.when_then_expr.is_empty() // only the searched form (no base expr) has boolean WHEN predicates
+ && case.when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
+ && info.is_boolean_type(&case.when_then_expr[0].1)? =>
{
// The disjunction of all the when predicates encountered so far
let mut filter_expr = lit(false);
// The disjunction of all the cases
let mut out_expr = lit(false);
- for (when, then) in when_then_expr {
+ for (when, then) in case.when_then_expr {
let case_expr = when
.as_ref()
.clone()
@@ -872,7 +869,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
filter_expr = filter_expr.or(*when);
}
- if let Some(else_expr) = else_expr {
+ if let Some(else_expr) = case.else_expr {
let case_expr = filter_expr.not().and(*else_expr);
out_expr = out_expr.or(case_expr);
}
@@ -974,6 +971,7 @@ mod tests {
use arrow::array::{ArrayRef, Int32Array};
use chrono::{DateTime, TimeZone, Utc};
use datafusion_common::{DFField, ToDFSchema};
+ use datafusion_expr::expr::Case;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano,
@@ -1700,14 +1698,14 @@ mod tests {
// -->
// false
assert_eq!(
- simplify(Expr::Case {
- expr: None,
- when_then_expr: vec![(
+ simplify(Expr::Case(Case::new(
+ None,
+ vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("not_ok"))),
)],
- else_expr: Some(Box::new(col("c2").eq(lit(true)))),
- }),
+ Some(Box::new(col("c2").eq(lit(true)))),
+ ))),
col("c2").not().and(col("c2")) // #1716
);
@@ -1720,14 +1718,14 @@ mod tests {
// Need to call simplify 2x due to
// https://github.com/apache/arrow-datafusion/issues/1160
assert_eq!(
- simplify(simplify(Expr::Case {
- expr: None,
- when_then_expr: vec![(
+ simplify(simplify(Expr::Case(Case::new(
+ None,
+ vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("ok"))),
)],
- else_expr: Some(Box::new(col("c2").eq(lit(true)))),
- })),
+ Some(Box::new(col("c2").eq(lit(true)))),
+ )))),
col("c2").or(col("c2").not().and(col("c2"))) // #1716
);
@@ -1738,14 +1736,11 @@ mod tests {
// Need to call simplify 2x due to
// https://github.com/apache/arrow-datafusion/issues/1160
assert_eq!(
- simplify(simplify(Expr::Case {
- expr: None,
- when_then_expr: vec![(
- Box::new(col("c2").is_null()),
- Box::new(lit(true)),
- )],
- else_expr: Some(Box::new(col("c2"))),
- })),
+ simplify(simplify(Expr::Case(Case::new(
+ None,
+ vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
+ Some(Box::new(col("c2"))),
+ )))),
col("c2")
.is_null()
.or(col("c2").is_not_null().and(col("c2")))
@@ -1759,14 +1754,14 @@ mod tests {
// Need to call simplify 2x due to
// https://github.com/apache/arrow-datafusion/issues/1160
assert_eq!(
- simplify(simplify(Expr::Case {
- expr: None,
- when_then_expr: vec![
+ simplify(simplify(Expr::Case(Case::new(
+ None,
+ vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
],
- else_expr: Some(Box::new(lit(true))),
- })),
+ Some(Box::new(lit(true))),
+ )))),
col("c1").or(col("c1").not().and(col("c2").not()))
);
@@ -1778,14 +1773,14 @@ mod tests {
// Need to call simplify 2x due to
// https://github.com/apache/arrow-datafusion/issues/1160
assert_eq!(
- simplify(simplify(Expr::Case {
- expr: None,
- when_then_expr: vec![
+ simplify(simplify(Expr::Case(Case::new(
+ None,
+ vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
],
- else_expr: Some(Box::new(lit(true))),
- })),
+ Some(Box::new(lit(true))),
+ )))),
col("c1").or(col("c1").not().and(col("c2").not()))
);
}
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index f0470da87..fcfe6eaaa 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -20,6 +20,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
+use datafusion_expr::expr::Case;
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion};
@@ -357,18 +358,15 @@ impl ExprRewriter for TypeCoercionRewriter {
}
}
}
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
+ Expr::Case(case) => {
// all the results of then and else should be converted to a common data type;
// if they cannot be coerced to a common data type, return an error.
- let then_types = when_then_expr
+ let then_types = case
+ .when_then_expr
.iter()
.map(|when_then| when_then.1.get_type(&self.schema))
.collect::<Result<Vec<_>>>()?;
- let else_type = match &else_expr {
+ let else_type = match &case.else_expr {
None => Ok(None),
Some(expr) => expr.get_type(&self.schema).map(Some),
}?;
@@ -380,24 +378,20 @@ impl ExprRewriter for TypeCoercionRewriter {
then_types, else_type
))),
Some(data_type) => {
- let left = when_then_expr
+ let left = case.when_then_expr
.into_iter()
.map(|(when, then)| {
let then = then.cast_to(&data_type, &self.schema)?;
Ok((when, Box::new(then)))
})
.collect::<Result<Vec<_>>>()?;
- let right = match else_expr {
+ let right = match &case.else_expr {
None => None,
Some(expr) => {
- Some(Box::new(expr.cast_to(&data_type, &self.schema)?))
+ Some(Box::new(expr.clone().cast_to(&data_type, &self.schema)?))
}
};
- Ok(Expr::Case {
- expr,
- when_then_expr: left,
- else_expr: right,
- })
+ Ok(Expr::Case(Case::new(case.expr,left,right)))
}
}
}
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 0964d6480..993891884 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -221,13 +221,8 @@ pub fn create_physical_expr(
binary_expr(expr.as_ref().clone(), op, pattern.as_ref().clone());
create_physical_expr(&bin_expr, input_dfschema, input_schema, execution_props)
}
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- ..
- } => {
- let expr: Option<Arc<dyn PhysicalExpr>> = if let Some(e) = expr {
+ Expr::Case(case) => {
+ let expr: Option<Arc<dyn PhysicalExpr>> = if let Some(e) = &case.expr {
Some(create_physical_expr(
e.as_ref(),
input_dfschema,
@@ -237,7 +232,8 @@ pub fn create_physical_expr(
} else {
None
};
- let when_expr = when_then_expr
+ let when_expr = case
+ .when_then_expr
.iter()
.map(|(w, _)| {
create_physical_expr(
@@ -248,7 +244,8 @@ pub fn create_physical_expr(
)
})
.collect::<Result<Vec<_>>>()?;
- let then_expr = when_then_expr
+ let then_expr = case
+ .when_then_expr
.iter()
.map(|(_, t)| {
create_physical_expr(
@@ -265,16 +262,17 @@ pub fn create_physical_expr(
.zip(then_expr.iter())
.map(|(w, t)| (w.clone(), t.clone()))
.collect();
- let else_expr: Option<Arc<dyn PhysicalExpr>> = if let Some(e) = else_expr {
- Some(create_physical_expr(
- e.as_ref(),
- input_dfschema,
- input_schema,
- execution_props,
- )?)
- } else {
- None
- };
+ let else_expr: Option<Arc<dyn PhysicalExpr>> =
+ if let Some(e) = &case.else_expr {
+ Some(create_physical_expr(
+ e.as_ref(),
+ input_dfschema,
+ input_schema,
+ execution_props,
+ )?)
+ } else {
+ None
+ };
Ok(expressions::case(expr, when_then_expr, else_expr)?)
}
Expr::Cast { expr, data_type } => expressions::cast(
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index 3eeb30edf..208c24036 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -31,8 +31,8 @@ use datafusion::logical_plan::FunctionRegistry;
use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue,
};
-use datafusion_expr::expr::GroupingSet;
use datafusion_expr::expr::GroupingSet::GroupingSets;
+use datafusion_expr::expr::{Case, GroupingSet};
use datafusion_expr::{
abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil,
character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_bin,
@@ -1023,11 +1023,11 @@ pub fn parse_expr(
Ok((Box::new(when_expr), Box::new(then_expr)))
})
.collect::<Result<Vec<(Box<Expr>, Box<Expr>)>, Error>>()?;
- Ok(Expr::Case {
- expr: parse_optional_expr(&case.expr, registry)?.map(Box::new),
+ Ok(Expr::Case(Case::new(
+ parse_optional_expr(&case.expr, registry)?.map(Box::new),
when_then_expr,
- else_expr: parse_optional_expr(&case.else_expr, registry)?.map(Box::new),
- })
+ parse_optional_expr(&case.else_expr, registry)?.map(Box::new),
+ )))
}
ExprType::Cast(cast) => {
let expr = Box::new(parse_required_expr(&cast.expr, registry, "expr")?);
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index e3b6c848a..7a495077f 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -63,7 +63,7 @@ mod roundtrip_tests {
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext};
use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
- use datafusion_expr::expr::GroupingSet;
+ use datafusion_expr::expr::{Case, GroupingSet};
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
use datafusion_expr::{
col, lit, Accumulator, AggregateFunction, AggregateState,
@@ -970,11 +970,11 @@ mod roundtrip_tests {
#[test]
fn roundtrip_case() {
- let test_expr = Expr::Case {
- expr: Some(Box::new(lit(1.0_f32))),
- when_then_expr: vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
- else_expr: Some(Box::new(lit(4.0_f32))),
- };
+ let test_expr = Expr::Case(Case::new(
+ Some(Box::new(lit(1.0_f32))),
+ vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
+ Some(Box::new(lit(4.0_f32))),
+ ));
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx);
@@ -982,11 +982,11 @@ mod roundtrip_tests {
#[test]
fn roundtrip_case_with_null() {
- let test_expr = Expr::Case {
- expr: Some(Box::new(lit(1.0_f32))),
- when_then_expr: vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
- else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))),
- };
+ let test_expr = Expr::Case(Case::new(
+ Some(Box::new(lit(1.0_f32))),
+ vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
+ Some(Box::new(Expr::Literal(ScalarValue::Null))),
+ ));
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx);
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 47b779fff..7b70821ec 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -771,12 +771,8 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
expr_type: Some(ExprType::Between(expr)),
}
}
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
- let when_then_expr = when_then_expr
+ Expr::Case(case) => {
+ let when_then_expr = case.when_then_expr
.iter()
.map(|(w, t)| {
Ok(protobuf::WhenThen {
@@ -786,12 +782,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
})
.collect::<Result<Vec<protobuf::WhenThen>, Error>>()?;
let expr = Box::new(protobuf::CaseNode {
- expr: match expr {
+ expr: match &case.expr {
Some(e) => Some(Box::new(e.as_ref().try_into()?)),
None => None,
},
when_then_expr,
- else_expr: match else_expr {
+ else_expr: match &case.else_expr {
Some(e) => Some(Box::new(e.as_ref().try_into()?)),
None => None,
},
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 9a4e29228..58b65af59 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -51,7 +51,7 @@ use crate::utils::{make_decimal_type, normalize_ident, resolve_columns};
use datafusion_common::{
field_not_found, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
-use datafusion_expr::expr::GroupingSet;
+use datafusion_expr::expr::{Case, GroupingSet};
use datafusion_expr::logical_plan::builder::project_with_alias;
use datafusion_expr::logical_plan::{Filter, Subquery};
use datafusion_expr::Expr::Alias;
@@ -1872,15 +1872,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
None
};
- Ok(Expr::Case {
+ Ok(Expr::Case(Case::new(
expr,
- when_then_expr: when_expr
+ when_expr
.iter()
.zip(then_expr.iter())
.map(|(w, t)| (Box::new(w.to_owned()), Box::new(t.to_owned())))
.collect(),
else_expr,
- })
+ )))
}
SQLExpr::Cast {
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index eb58509d0..952ef3110 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -21,7 +21,7 @@ use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE
use sqlparser::ast::Ident;
use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::expr::GroupingSet;
+use datafusion_expr::expr::{Case, GroupingSet};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
use datafusion_expr::{Expr, LogicalPlan};
use std::collections::HashMap;
@@ -268,18 +268,14 @@ where
pattern: Box::new(clone_with_replacement(pattern, replacement_fn)?),
escape_char: *escape_char,
}),
- Expr::Case {
- expr: case_expr_opt,
- when_then_expr,
- else_expr: else_expr_opt,
- } => Ok(Expr::Case {
- expr: match case_expr_opt {
+ Expr::Case(case) => Ok(Expr::Case(Case::new(
+ match &case.expr {
Some(case_expr) => {
Some(Box::new(clone_with_replacement(case_expr, replacement_fn)?))
}
None => None,
},
- when_then_expr: when_then_expr
+ case.when_then_expr
.iter()
.map(|(a, b)| {
Ok((
@@ -288,13 +284,13 @@ where
))
})
.collect::<Result<Vec<(_, _)>>>()?,
- else_expr: match else_expr_opt {
+ match &case.else_expr {
Some(else_expr) => {
Some(Box::new(clone_with_replacement(else_expr, replacement_fn)?))
}
None => None,
},
- }),
+ ))),
Expr::ScalarFunction { fun, args } => Ok(Expr::ScalarFunction {
fun: fun.clone(),
args: args