You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/10/04 09:38:24 UTC
[arrow-datafusion] branch master updated: move `type coercion` for case when expr (#3676)
This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 01117325d move `type coercion` for case when expr (#3676)
01117325d is described below
commit 01117325da728f094f235a77b4acec9a7c9bc7e9
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Oct 4 17:38:18 2022 +0800
move `type coercion` for case when expr (#3676)
* support type coercion in logical phase and remove it in the physical phase
* Apply suggestions from code review
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* format code
* change error message
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/core/tests/sql/expr.rs | 52 ++++----
datafusion/core/tests/sql/projection.rs | 14 +-
datafusion/optimizer/src/type_coercion.rs | 66 +++++++++
datafusion/physical-expr/src/expressions/case.rs | 162 +++++++++++++----------
datafusion/physical-expr/src/planner.rs | 7 +-
5 files changed, 194 insertions(+), 107 deletions(-)
diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index b1e8227e0..93e453101 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -152,12 +152,12 @@ async fn case_expr_with_null() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+------------------------------------------------+",
- "| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |",
- "+------------------------------------------------+",
- "| |",
- "| 3 |",
- "+------------------------------------------------+",
+ "+----------------------------------------------+",
+ "| CASE WHEN a.b IS NULL THEN NULL ELSE a.b END |",
+ "+----------------------------------------------+",
+ "| |",
+ "| 3 |",
+ "+----------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -165,12 +165,12 @@ async fn case_expr_with_null() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+------------------------------------------------+",
- "| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |",
- "+------------------------------------------------+",
- "| 1 |",
- "| 3 |",
- "+------------------------------------------------+",
+ "+----------------------------------------------+",
+ "| CASE WHEN a.b IS NULL THEN NULL ELSE a.b END |",
+ "+----------------------------------------------+",
+ "| 1 |",
+ "| 3 |",
+ "+----------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -184,13 +184,13 @@ async fn case_expr_with_nulls() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+--------------------------------------------------------------------------------------------------------------------------+",
- "| CASE WHEN #a.b IS NULL THEN NULL WHEN #a.b < Int64(3) THEN NULL WHEN #a.b >= Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |",
- "+--------------------------------------------------------------------------------------------------------------------------+",
- "| |",
- "| |",
- "| 4 |",
- "+--------------------------------------------------------------------------------------------------------------------------+"
+ "+---------------------------------------------------------------------------------------------------------------------+",
+ "| CASE WHEN a.b IS NULL THEN NULL WHEN a.b < Int64(3) THEN NULL WHEN a.b >= Int64(3) THEN a.b + Int64(1) ELSE a.b END |",
+ "+---------------------------------------------------------------------------------------------------------------------+",
+ "| |",
+ "| |",
+ "| 4 |",
+ "+---------------------------------------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -198,13 +198,13 @@ async fn case_expr_with_nulls() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+------------------------------------------------------------------------------------------------------------+",
- "| CASE #a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |",
- "+------------------------------------------------------------------------------------------------------------+",
- "| |",
- "| |",
- "| 4 |",
- "+------------------------------------------------------------------------------------------------------------+",
+ "+---------------------------------------------------------------------------------------------------------+",
+ "| CASE a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN a.b + Int64(1) ELSE a.b END |",
+ "+---------------------------------------------------------------------------------------------------------+",
+ "| |",
+ "| |",
+ "| 4 |",
+ "+---------------------------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs
index 97c6dcf8a..cee8e706c 100644
--- a/datafusion/core/tests/sql/projection.rs
+++ b/datafusion/core/tests/sql/projection.rs
@@ -252,13 +252,13 @@ async fn project_cast_dictionary() {
let actual = collect(physical_plan, ctx.task_ctx()).await.unwrap();
let expected = vec![
- "+------------------------------------------------------------------------------------+",
- "| CASE WHEN #cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE #cpu_load_short.host END |",
- "+------------------------------------------------------------------------------------+",
- "| host1 |",
- "| |",
- "| host2 |",
- "+------------------------------------------------------------------------------------+",
+ "+----------------------------------------------------------------------------------+",
+ "| CASE WHEN cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE cpu_load_short.host END |",
+ "+----------------------------------------------------------------------------------+",
+ "| host1 |",
+ "| |",
+ "| host2 |",
+ "+----------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
}
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index 5a53adc26..c7f107b5e 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -354,6 +354,50 @@ impl ExprRewriter for TypeCoercionRewriter {
}
}
}
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ } => {
+ // all the results of THEN and ELSE must be converted to a common data type;
+ // if they cannot be coerced to a common data type, return an error.
+ let then_types = when_then_expr
+ .iter()
+ .map(|when_then| when_then.1.get_type(&self.schema))
+ .collect::<Result<Vec<_>>>()?;
+ let else_type = match &else_expr {
+ None => Ok(None),
+ Some(expr) => expr.get_type(&self.schema).map(Some),
+ }?;
+ let case_when_coerce_type =
+ get_coerce_type_for_case_when(&then_types, &else_type);
+ match case_when_coerce_type {
+ None => Err(DataFusionError::Internal(format!(
+ "Failed to coerce then ({:?}) and else ({:?}) to common types in CASE WHEN expression",
+ then_types, else_type
+ ))),
+ Some(data_type) => {
+ let left = when_then_expr
+ .into_iter()
+ .map(|(when, then)| {
+ let then = then.cast_to(&data_type, &self.schema)?;
+ Ok((when, Box::new(then)))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let right = match else_expr {
+ None => None,
+ Some(expr) => {
+ Some(Box::new(expr.cast_to(&data_type, &self.schema)?))
+ }
+ };
+ Ok(Expr::Case {
+ expr,
+ when_then_expr: left,
+ else_expr: right,
+ })
+ }
+ }
+ }
expr => Ok(expr),
}
}
@@ -410,6 +454,28 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}
+/// Find a common type that every entry of `then_types` and the `else_type`
+/// (if specified) can be coerced to, using the comparison coercion rule.
+///
+/// Returns `None` if no such common type exists, or if there is neither a
+/// THEN type nor an ELSE type to start from (e.g. a CASE built with no WHEN
+/// arms and no ELSE), so the caller reports a coercion error instead of the
+/// optimizer panicking.
+fn get_coerce_type_for_case_when(
+ then_types: &[DataType],
+ else_type: &Option<DataType>,
+) -> Option<DataType> {
+ // Seed with the ELSE type, or with the first THEN type when there is no
+ // ELSE. `first()?` avoids an index panic on an empty WHEN list.
+ let else_type = match else_type {
+ None => then_types.first()?.clone(),
+ Some(data_type) => data_type.clone(),
+ };
+ // `try_fold` stops at the first pair that has no common type.
+ // TODO: this reuses the comparison (`=`) coercion rule for CASE; revisit
+ // if it proves too permissive or too strict.
+ then_types
+ .iter()
+ .try_fold(else_type, |left_type, right_type| {
+ comparison_coercion(&left_type, right_type)
+ })
+}
+
#[cfg(test)]
mod test {
use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs
index cf4f7defe..b1bd0a604 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -25,7 +25,6 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::binary_rule::comparison_coercion;
use datafusion_expr::ColumnarValue;
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
@@ -294,66 +293,10 @@ pub fn case(
expr: Option<Arc<dyn PhysicalExpr>>,
when_thens: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
- input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
- // all the result of then and else should be convert to a common data type,
- // if they can be coercible to a common data type, return error.
- let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema);
- let (when_thens, else_expr) = match coerce_type {
- None => Err(DataFusionError::Plan(format!(
- "Can't get a common type for then {:?} and else {:?} expression",
- when_thens, else_expr
- ))),
- Some(data_type) => {
- // cast then expr
- let left = when_thens
- .into_iter()
- .map(|(when, then)| {
- let then = try_cast(then, input_schema, data_type.clone())?;
- Ok((when, then))
- })
- .collect::<Result<Vec<_>>>()?;
- let right = match else_expr {
- None => None,
- Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
- };
-
- Ok((left, right))
- }
- }?;
-
Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
}
-fn get_case_common_type(
- when_thens: &[WhenThen],
- else_expr: Option<Arc<dyn PhysicalExpr>>,
- input_schema: &Schema,
-) -> Option<DataType> {
- let thens_type = when_thens
- .iter()
- .map(|when_then| {
- let data_type = &when_then.1.data_type(input_schema).unwrap();
- data_type.clone()
- })
- .collect::<Vec<_>>();
- let else_type = match else_expr {
- None => {
- // case when then exprs must have one then value
- thens_type[0].clone()
- }
- Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
- };
- thens_type
- .iter()
- .fold(Some(else_type), |left, right_type| match left {
- None => None,
- // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
- // refactor again.
- Some(left_type) => comparison_coercion(&left_type, right_type),
- })
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -365,6 +308,7 @@ mod tests {
use arrow::datatypes::DataType::Float64;
use arrow::datatypes::*;
use datafusion_common::ScalarValue;
+ use datafusion_expr::binary_rule::comparison_coercion;
use datafusion_expr::Operator;
#[test]
@@ -378,7 +322,7 @@ mod tests {
let when2 = lit("bar");
let then2 = lit(456i32);
- let expr = case(
+ let expr = generate_case_when_with_type_coercion(
Some(col("a", &schema)?),
vec![(when1, then1), (when2, then2)],
None,
@@ -409,7 +353,7 @@ mod tests {
let then2 = lit(456i32);
let else_value = lit(999i32);
- let expr = case(
+ let expr = generate_case_when_with_type_coercion(
Some(col("a", &schema)?),
vec![(when1, then1), (when2, then2)],
Some(else_value),
@@ -444,7 +388,7 @@ mod tests {
&batch.schema(),
)?;
- let expr = case(
+ let expr = generate_case_when_with_type_coercion(
Some(col("a", &schema)?),
vec![(when1, then1)],
Some(else_value),
@@ -484,7 +428,7 @@ mod tests {
)?;
let then2 = lit(456i32);
- let expr = case(
+ let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1), (when2, then2)],
None,
@@ -518,7 +462,12 @@ mod tests {
)?;
let x = lit(ScalarValue::Float64(None));
- let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?;
+ let expr = generate_case_when_with_type_coercion(
+ None,
+ vec![(when1, then1)],
+ Some(x),
+ schema.as_ref(),
+ )?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -561,7 +510,7 @@ mod tests {
let then2 = lit(456i32);
let else_value = lit(999i32);
- let expr = case(
+ let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1), (when2, then2)],
Some(else_value),
@@ -596,7 +545,12 @@ mod tests {
let then = lit(123.3f64);
let else_value = lit(999i32);
- let expr = case(None, vec![(when, then)], Some(else_value), schema.as_ref())?;
+ let expr = generate_case_when_with_type_coercion(
+ None,
+ vec![(when, then)],
+ Some(else_value),
+ schema.as_ref(),
+ )?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -625,7 +579,12 @@ mod tests {
)?;
let then = col("load4", &schema)?;
- let expr = case(None, vec![(when, then)], None, schema.as_ref())?;
+ let expr = generate_case_when_with_type_coercion(
+ None,
+ vec![(when, then)],
+ None,
+ schema.as_ref(),
+ )?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -650,7 +609,12 @@ mod tests {
let when = lit(1.77f64);
let then = col("load4", &schema)?;
- let expr = case(Some(expr), vec![(when, then)], None, schema.as_ref())?;
+ let expr = generate_case_when_with_type_coercion(
+ Some(expr),
+ vec![(when, then)],
+ None,
+ schema.as_ref(),
+ )?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -724,7 +688,7 @@ mod tests {
)?;
let then2 = lit(true);
- let expr = case(
+ let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1), (when2, then2)],
None,
@@ -752,7 +716,7 @@ mod tests {
let then2 = lit(456i64);
let else_expr = lit(1.23f64);
- let expr = case(
+ let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1), (when2, then2)],
Some(else_expr),
@@ -763,4 +727,66 @@ mod tests {
assert_eq!(DataType::Float64, result_type);
Ok(())
}
+
+ /// Test helper that mimics the logical `TypeCoercion` rule: every THEN
+ /// value and the optional ELSE value are cast to their common type, and
+ /// then the physical CASE expression is built from the cast branches.
+ fn generate_case_when_with_type_coercion(
+ expr: Option<Arc<dyn PhysicalExpr>>,
+ when_thens: Vec<WhenThen>,
+ else_expr: Option<Arc<dyn PhysicalExpr>>,
+ input_schema: &Schema,
+ ) -> Result<Arc<dyn PhysicalExpr>> {
+ let common_type =
+ match get_case_common_type(&when_thens, else_expr.clone(), input_schema) {
+ Some(data_type) => data_type,
+ None => {
+ return Err(DataFusionError::Plan(format!(
+ "Can't get a common type for then {:?} and else {:?} expression",
+ when_thens, else_expr
+ )));
+ }
+ };
+ // Cast THEN branches first (in order), then the ELSE branch, stopping
+ // at the first failed cast.
+ let mut cast_when_thens = Vec::with_capacity(when_thens.len());
+ for (when, then) in when_thens {
+ let cast_then = try_cast(then, input_schema, common_type.clone())?;
+ cast_when_thens.push((when, cast_then));
+ }
+ let cast_else = else_expr
+ .map(|else_value| try_cast(else_value, input_schema, common_type.clone()))
+ .transpose()?;
+ case(expr, cast_when_thens, cast_else)
+ }
+
+ /// Test-only mirror of the optimizer's CASE coercion: find a type that all
+ /// THEN values and the ELSE value (if any) can be coerced to.
+ ///
+ /// Returns `None` when no common type exists or when there are no THEN
+ /// values and no ELSE value. The `unwrap`s on `data_type` are acceptable
+ /// because this helper only runs in tests.
+ fn get_case_common_type(
+ when_thens: &[WhenThen],
+ else_expr: Option<Arc<dyn PhysicalExpr>>,
+ input_schema: &Schema,
+ ) -> Option<DataType> {
+ let thens_type = when_thens
+ .iter()
+ .map(|(_, then)| then.data_type(input_schema).unwrap())
+ .collect::<Vec<_>>();
+ // Seed with the ELSE type, or the first THEN type when there is no ELSE;
+ // `first()?` yields `None` instead of panicking on an empty WHEN list.
+ let else_type = match else_expr {
+ None => thens_type.first()?.clone(),
+ Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
+ };
+ // TODO: this reuses the comparison (`=`) coercion rule for CASE; revisit
+ // if it proves too permissive or too strict.
+ thens_type
+ .iter()
+ .try_fold(else_type, |left_type, right_type| {
+ comparison_coercion(&left_type, right_type)
+ })
+ }
}
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index ba9664ef6..0964d6480 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -275,12 +275,7 @@ pub fn create_physical_expr(
} else {
None
};
- Ok(expressions::case(
- expr,
- when_then_expr,
- else_expr,
- input_schema,
- )?)
+ Ok(expressions::case(expr, when_then_expr, else_expr)?)
}
Expr::Cast { expr, data_type } => expressions::cast(
create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,