You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/10/04 09:38:24 UTC

[arrow-datafusion] branch master updated: move `type coercion` for case when expr (#3676)

This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 01117325d move `type coercion` for case when expr (#3676)
01117325d is described below

commit 01117325da728f094f235a77b4acec9a7c9bc7e9
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Oct 4 17:38:18 2022 +0800

    move `type coercion` for case when expr (#3676)
    
    * support type coercion in logical phase and remove it in the physical phase
    
    * Apply suggestions from code review
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * format code
    
    * change error message
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/core/tests/sql/expr.rs                |  52 ++++----
 datafusion/core/tests/sql/projection.rs          |  14 +-
 datafusion/optimizer/src/type_coercion.rs        |  66 +++++++++
 datafusion/physical-expr/src/expressions/case.rs | 162 +++++++++++++----------
 datafusion/physical-expr/src/planner.rs          |   7 +-
 5 files changed, 194 insertions(+), 107 deletions(-)

diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index b1e8227e0..93e453101 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -152,12 +152,12 @@ async fn case_expr_with_null() -> Result<()> {
     let actual = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+------------------------------------------------+",
-        "| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |",
-        "+------------------------------------------------+",
-        "|                                                |",
-        "| 3                                              |",
-        "+------------------------------------------------+",
+        "+----------------------------------------------+",
+        "| CASE WHEN a.b IS NULL THEN NULL ELSE a.b END |",
+        "+----------------------------------------------+",
+        "|                                              |",
+        "| 3                                            |",
+        "+----------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
@@ -165,12 +165,12 @@ async fn case_expr_with_null() -> Result<()> {
     let actual = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+------------------------------------------------+",
-        "| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |",
-        "+------------------------------------------------+",
-        "| 1                                              |",
-        "| 3                                              |",
-        "+------------------------------------------------+",
+        "+----------------------------------------------+",
+        "| CASE WHEN a.b IS NULL THEN NULL ELSE a.b END |",
+        "+----------------------------------------------+",
+        "| 1                                            |",
+        "| 3                                            |",
+        "+----------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
@@ -184,13 +184,13 @@ async fn case_expr_with_nulls() -> Result<()> {
     let actual = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+--------------------------------------------------------------------------------------------------------------------------+",
-        "| CASE WHEN #a.b IS NULL THEN NULL WHEN #a.b < Int64(3) THEN NULL WHEN #a.b >= Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |",
-        "+--------------------------------------------------------------------------------------------------------------------------+",
-        "|                                                                                                                          |",
-        "|                                                                                                                          |",
-        "| 4                                                                                                                        |",
-        "+--------------------------------------------------------------------------------------------------------------------------+"
+        "+---------------------------------------------------------------------------------------------------------------------+",
+        "| CASE WHEN a.b IS NULL THEN NULL WHEN a.b < Int64(3) THEN NULL WHEN a.b >= Int64(3) THEN a.b + Int64(1) ELSE a.b END |",
+        "+---------------------------------------------------------------------------------------------------------------------+",
+        "|                                                                                                                     |",
+        "|                                                                                                                     |",
+        "| 4                                                                                                                   |",
+        "+---------------------------------------------------------------------------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
@@ -198,13 +198,13 @@ async fn case_expr_with_nulls() -> Result<()> {
     let actual = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+------------------------------------------------------------------------------------------------------------+",
-        "| CASE #a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |",
-        "+------------------------------------------------------------------------------------------------------------+",
-        "|                                                                                                            |",
-        "|                                                                                                            |",
-        "| 4                                                                                                          |",
-        "+------------------------------------------------------------------------------------------------------------+",
+        "+---------------------------------------------------------------------------------------------------------+",
+        "| CASE a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN a.b + Int64(1) ELSE a.b END |",
+        "+---------------------------------------------------------------------------------------------------------+",
+        "|                                                                                                         |",
+        "|                                                                                                         |",
+        "| 4                                                                                                       |",
+        "+---------------------------------------------------------------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs
index 97c6dcf8a..cee8e706c 100644
--- a/datafusion/core/tests/sql/projection.rs
+++ b/datafusion/core/tests/sql/projection.rs
@@ -252,13 +252,13 @@ async fn project_cast_dictionary() {
     let actual = collect(physical_plan, ctx.task_ctx()).await.unwrap();
 
     let expected = vec![
-        "+------------------------------------------------------------------------------------+",
-        "| CASE WHEN #cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE #cpu_load_short.host END |",
-        "+------------------------------------------------------------------------------------+",
-        "| host1                                                                              |",
-        "|                                                                                    |",
-        "| host2                                                                              |",
-        "+------------------------------------------------------------------------------------+",
+        "+----------------------------------------------------------------------------------+",
+        "| CASE WHEN cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE cpu_load_short.host END |",
+        "+----------------------------------------------------------------------------------+",
+        "| host1                                                                            |",
+        "|                                                                                  |",
+        "| host2                                                                            |",
+        "+----------------------------------------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 }
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index 5a53adc26..c7f107b5e 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -354,6 +354,50 @@ impl ExprRewriter for TypeCoercionRewriter {
                     }
                 }
             }
+            Expr::Case {
+                expr,
+                when_then_expr,
+                else_expr,
+            } => {
+                // All `then` and `else` results must be converted to a common data type;
+                // if they cannot be coerced to a common data type, return an error.
+                let then_types = when_then_expr
+                    .iter()
+                    .map(|when_then| when_then.1.get_type(&self.schema))
+                    .collect::<Result<Vec<_>>>()?;
+                let else_type = match &else_expr {
+                    None => Ok(None),
+                    Some(expr) => expr.get_type(&self.schema).map(Some),
+                }?;
+                let case_when_coerce_type =
+                    get_coerce_type_for_case_when(&then_types, &else_type);
+                match case_when_coerce_type {
+                    None => Err(DataFusionError::Internal(format!(
+                        "Failed to coerce then ({:?}) and else ({:?}) to common types in CASE WHEN expression",
+                        then_types, else_type
+                    ))),
+                    Some(data_type) => {
+                        let left = when_then_expr
+                            .into_iter()
+                            .map(|(when, then)| {
+                                let then = then.cast_to(&data_type, &self.schema)?;
+                                Ok((when, Box::new(then)))
+                            })
+                            .collect::<Result<Vec<_>>>()?;
+                        let right = match else_expr {
+                            None => None,
+                            Some(expr) => {
+                                Some(Box::new(expr.cast_to(&data_type, &self.schema)?))
+                            }
+                        };
+                        Ok(Expr::Case {
+                            expr,
+                            when_then_expr: left,
+                            else_expr: right,
+                        })
+                    }
+                }
+            }
             expr => Ok(expr),
         }
     }
@@ -410,6 +454,28 @@ fn coerce_arguments_for_signature(
         .collect::<Result<Vec<_>>>()
 }
 
+/// Find a common coercible type for all `then_types` as well
+/// as the `else_type`, if specified.
+/// Returns the common data type for `then_types` and `else_type`, or `None` if there is none.
+fn get_coerce_type_for_case_when(
+    then_types: &[DataType],
+    else_type: &Option<DataType>,
+) -> Option<DataType> {
+    let else_type = match else_type {
+        None => then_types[0].clone(),
+        Some(data_type) => data_type.clone(),
+    };
+    then_types
+        .iter()
+        .fold(Some(else_type), |left, right_type| match left {
+            // failed to find a valid coercion in a previous iteration
+            None => None,
+            // TODO: for now, just use the `equal` coercion rule for case when. If an issue is
+            // found, refactor again.
+            Some(left_type) => comparison_coercion(&left_type, right_type),
+        })
+}
+
 #[cfg(test)]
 mod test {
     use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs
index cf4f7defe..b1bd0a604 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -25,7 +25,6 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene};
 use arrow::datatypes::{DataType, Schema};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::binary_rule::comparison_coercion;
 use datafusion_expr::ColumnarValue;
 
 type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
@@ -294,66 +293,10 @@ pub fn case(
     expr: Option<Arc<dyn PhysicalExpr>>,
     when_thens: Vec<WhenThen>,
     else_expr: Option<Arc<dyn PhysicalExpr>>,
-    input_schema: &Schema,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    // all the result of then and else should be convert to a common data type,
-    // if they can be coercible to a common data type, return error.
-    let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema);
-    let (when_thens, else_expr) = match coerce_type {
-        None => Err(DataFusionError::Plan(format!(
-            "Can't get a common type for then {:?} and else {:?} expression",
-            when_thens, else_expr
-        ))),
-        Some(data_type) => {
-            // cast then expr
-            let left = when_thens
-                .into_iter()
-                .map(|(when, then)| {
-                    let then = try_cast(then, input_schema, data_type.clone())?;
-                    Ok((when, then))
-                })
-                .collect::<Result<Vec<_>>>()?;
-            let right = match else_expr {
-                None => None,
-                Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
-            };
-
-            Ok((left, right))
-        }
-    }?;
-
     Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
 }
 
-fn get_case_common_type(
-    when_thens: &[WhenThen],
-    else_expr: Option<Arc<dyn PhysicalExpr>>,
-    input_schema: &Schema,
-) -> Option<DataType> {
-    let thens_type = when_thens
-        .iter()
-        .map(|when_then| {
-            let data_type = &when_then.1.data_type(input_schema).unwrap();
-            data_type.clone()
-        })
-        .collect::<Vec<_>>();
-    let else_type = match else_expr {
-        None => {
-            // case when then exprs must have one then value
-            thens_type[0].clone()
-        }
-        Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
-    };
-    thens_type
-        .iter()
-        .fold(Some(else_type), |left, right_type| match left {
-            None => None,
-            // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
-            // refactor again.
-            Some(left_type) => comparison_coercion(&left_type, right_type),
-        })
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -365,6 +308,7 @@ mod tests {
     use arrow::datatypes::DataType::Float64;
     use arrow::datatypes::*;
     use datafusion_common::ScalarValue;
+    use datafusion_expr::binary_rule::comparison_coercion;
     use datafusion_expr::Operator;
 
     #[test]
@@ -378,7 +322,7 @@ mod tests {
         let when2 = lit("bar");
         let then2 = lit(456i32);
 
-        let expr = case(
+        let expr = generate_case_when_with_type_coercion(
             Some(col("a", &schema)?),
             vec![(when1, then1), (when2, then2)],
             None,
@@ -409,7 +353,7 @@ mod tests {
         let then2 = lit(456i32);
         let else_value = lit(999i32);
 
-        let expr = case(
+        let expr = generate_case_when_with_type_coercion(
             Some(col("a", &schema)?),
             vec![(when1, then1), (when2, then2)],
             Some(else_value),
@@ -444,7 +388,7 @@ mod tests {
             &batch.schema(),
         )?;
 
-        let expr = case(
+        let expr = generate_case_when_with_type_coercion(
             Some(col("a", &schema)?),
             vec![(when1, then1)],
             Some(else_value),
@@ -484,7 +428,7 @@ mod tests {
         )?;
         let then2 = lit(456i32);
 
-        let expr = case(
+        let expr = generate_case_when_with_type_coercion(
             None,
             vec![(when1, then1), (when2, then2)],
             None,
@@ -518,7 +462,12 @@ mod tests {
         )?;
         let x = lit(ScalarValue::Float64(None));
 
-        let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?;
+        let expr = generate_case_when_with_type_coercion(
+            None,
+            vec![(when1, then1)],
+            Some(x),
+            schema.as_ref(),
+        )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -561,7 +510,7 @@ mod tests {
         let then2 = lit(456i32);
         let else_value = lit(999i32);
 
-        let expr = case(
+        let expr = generate_case_when_with_type_coercion(
             None,
             vec![(when1, then1), (when2, then2)],
             Some(else_value),
@@ -596,7 +545,12 @@ mod tests {
         let then = lit(123.3f64);
         let else_value = lit(999i32);
 
-        let expr = case(None, vec![(when, then)], Some(else_value), schema.as_ref())?;
+        let expr = generate_case_when_with_type_coercion(
+            None,
+            vec![(when, then)],
+            Some(else_value),
+            schema.as_ref(),
+        )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -625,7 +579,12 @@ mod tests {
         )?;
         let then = col("load4", &schema)?;
 
-        let expr = case(None, vec![(when, then)], None, schema.as_ref())?;
+        let expr = generate_case_when_with_type_coercion(
+            None,
+            vec![(when, then)],
+            None,
+            schema.as_ref(),
+        )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -650,7 +609,12 @@ mod tests {
         let when = lit(1.77f64);
         let then = col("load4", &schema)?;
 
-        let expr = case(Some(expr), vec![(when, then)], None, schema.as_ref())?;
+        let expr = generate_case_when_with_type_coercion(
+            Some(expr),
+            vec![(when, then)],
+            None,
+            schema.as_ref(),
+        )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -724,7 +688,7 @@ mod tests {
         )?;
         let then2 = lit(true);
 
-        let expr = case(
+        let expr = generate_case_when_with_type_coercion(
             None,
             vec![(when1, then1), (when2, then2)],
             None,
@@ -752,7 +716,7 @@ mod tests {
         let then2 = lit(456i64);
         let else_expr = lit(1.23f64);
 
-        let expr = case(
+        let expr = generate_case_when_with_type_coercion(
             None,
             vec![(when1, then1), (when2, then2)],
             Some(else_expr),
@@ -763,4 +727,66 @@ mod tests {
         assert_eq!(DataType::Float64, result_type);
         Ok(())
     }
+
+    fn generate_case_when_with_type_coercion(
+        expr: Option<Arc<dyn PhysicalExpr>>,
+        when_thens: Vec<WhenThen>,
+        else_expr: Option<Arc<dyn PhysicalExpr>>,
+        input_schema: &Schema,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        let coerce_type =
+            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
+        let (when_thens, else_expr) = match coerce_type {
+            None => Err(DataFusionError::Plan(format!(
+                "Can't get a common type for then {:?} and else {:?} expression",
+                when_thens, else_expr
+            ))),
+            Some(data_type) => {
+                // cast then expr
+                let left = when_thens
+                    .into_iter()
+                    .map(|(when, then)| {
+                        let then = try_cast(then, input_schema, data_type.clone())?;
+                        Ok((when, then))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let right = match else_expr {
+                    None => None,
+                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
+                };
+
+                Ok((left, right))
+            }
+        }?;
+        case(expr, when_thens, else_expr)
+    }
+
+    fn get_case_common_type(
+        when_thens: &[WhenThen],
+        else_expr: Option<Arc<dyn PhysicalExpr>>,
+        input_schema: &Schema,
+    ) -> Option<DataType> {
+        let thens_type = when_thens
+            .iter()
+            .map(|when_then| {
+                let data_type = &when_then.1.data_type(input_schema).unwrap();
+                data_type.clone()
+            })
+            .collect::<Vec<_>>();
+        let else_type = match else_expr {
+            None => {
+                // case when then exprs must have one then value
+                thens_type[0].clone()
+            }
+            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
+        };
+        thens_type
+            .iter()
+            .fold(Some(else_type), |left, right_type| match left {
+                None => None,
+                // TODO: for now, just use the `equal` coercion rule for case when. If an issue is
+                // found, refactor again.
+                Some(left_type) => comparison_coercion(&left_type, right_type),
+            })
+    }
 }
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index ba9664ef6..0964d6480 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -275,12 +275,7 @@ pub fn create_physical_expr(
             } else {
                 None
             };
-            Ok(expressions::case(
-                expr,
-                when_then_expr,
-                else_expr,
-                input_schema,
-            )?)
+            Ok(expressions::case(expr, when_then_expr, else_expr)?)
         }
         Expr::Cast { expr, data_type } => expressions::cast(
             create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,