You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/09 17:39:40 UTC

[arrow-datafusion] branch main updated: fix: remove type coercion of case expression in Expr::Schema (#6614)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 005ecdc485 fix: remove type coercion of case expression in Expr::Schema (#6614)
005ecdc485 is described below

commit 005ecdc485a322ab97113616d94eb09c16446324
Author: jakevin <ja...@gmail.com>
AuthorDate: Sat Jun 10 01:39:33 2023 +0800

    fix: remove type coercion of case expression in Expr::Schema (#6614)
    
    * fix: remove type coercion of case expression in Expr::Schema
    
    * remove useless check, because we shouldn't require the logical plan's schema to equal the optimized plan's schema
---
 datafusion/core/tests/sql/mod.rs                   |  9 ++-------
 .../core/tests/sqllogictests/test_files/scalar.slt |  8 ++++----
 .../core/tests/sqllogictests/test_files/union.slt  |  4 ++--
 datafusion/expr/src/expr_schema.rs                 | 22 +---------------------
 datafusion/optimizer/src/analyzer/type_coercion.rs | 11 +++++++++--
 5 files changed, 18 insertions(+), 36 deletions(-)

diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 0413a06b6a..943254ca46 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -971,13 +971,8 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordB
 async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
     let df = ctx.sql(sql).await.unwrap();
 
-    // We are not really interested in the direct output of optimized_logical_plan
-    // since the physical plan construction already optimizes the given logical plan
-    // and we want to avoid double-optimization as a consequence. So we just construct
-    // it here to make sure that it doesn't fail at this step and get the optimized
-    // schema (to assert later that the logical and optimized schemas are the same).
-    let optimized = df.clone().into_optimized_plan().unwrap();
-    assert_eq!(df.logical_plan().schema(), optimized.schema());
+    // Optimize only to check that the schema doesn't change during optimization.
+    df.clone().into_optimized_plan().unwrap();
 
     df.collect().await.unwrap()
 }
diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt b/datafusion/core/tests/sqllogictests/test_files/scalar.slt
index a4adbaefd4..2d1925702c 100644
--- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt
@@ -646,27 +646,27 @@ SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END
 bar
 
 # case_expr_with_null()
-query I
+query ?
 select case when b is null then null else b end from (select a,b from (values (1,null),(2,3)) as t (a,b)) a;
 ----
 NULL
 3
 
-query I
+query ?
 select case when b is null then null else b end from (select a,b from (values (1,1),(2,3)) as t (a,b)) a;
 ----
 1
 3
 
 # case_expr_with_nulls()
-query I
+query ?
 select case when b is null then null when b < 3 then null when b >=3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a
 ----
 NULL
 NULL
 4
 
-query I
+query ?
 select case b when 1 then null when 2 then null when 3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a;
 ----
 NULL
diff --git a/datafusion/core/tests/sqllogictests/test_files/union.slt b/datafusion/core/tests/sqllogictests/test_files/union.slt
index 5779d5153e..2f33437ca1 100644
--- a/datafusion/core/tests/sqllogictests/test_files/union.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/union.slt
@@ -432,13 +432,13 @@ logical_plan
 Sort: t1.c1 ASC NULLS LAST
 --Union
 ----TableScan: t1 projection=[c1]
-----Projection: t2.c1a AS t1.c1
+----Projection: t2.c1a AS c1
 ------TableScan: t2 projection=[c1a]
 physical_plan
 SortPreservingMergeExec: [c1@0 ASC NULLS LAST]
 --UnionExec
 ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], output_ordering=[c1@0 ASC NULLS LAST], has_header=true
-----ProjectionExec: expr=[c1a@0 as t1.c1]
+----ProjectionExec: expr=[c1a@0 as c1]
 ------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1a], output_ordering=[c1a@0 ASC NULLS LAST], has_header=true
 
 statement ok
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index f502676625..3c68c4acd7 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -22,7 +22,6 @@ use crate::expr::{
 };
 use crate::field_util::get_indexed_field;
 use crate::type_coercion::binary::get_result_type;
-use crate::type_coercion::other::get_coerce_type_for_case_expression;
 use crate::{
     aggregate_function, function, window_function, LogicalPlan, Projection, Subquery,
 };
@@ -73,26 +72,7 @@ impl ExprSchemable for Expr {
             Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
             Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
             Expr::Literal(l) => Ok(l.get_datatype()),
-            Expr::Case(case) => {
-                // https://github.com/apache/arrow-datafusion/issues/5821
-                // when #5681 will be fixed, this code can be reverted to:
-                // case.when_then_expr[0].1.get_type(schema)
-                let then_types = case
-                    .when_then_expr
-                    .iter()
-                    .map(|when_then| when_then.1.get_type(schema))
-                    .collect::<Result<Vec<_>>>()?;
-                let else_type = match &case.else_expr {
-                    None => Ok(None),
-                    Some(expr) => expr.get_type(schema).map(Some),
-                }?;
-                get_coerce_type_for_case_expression(&then_types, else_type.as_ref())
-                    .ok_or_else(|| {
-                        DataFusionError::Internal(String::from(
-                            "Cannot infer type for CASE statement",
-                        ))
-                    })
-            }
+            Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
             Expr::Cast(Cast { data_type, .. })
             | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
             Expr::ScalarUDF(ScalarUDF { fun, args }) => {
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index da033a54ae..0d0061a5e4 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -42,7 +42,7 @@ use datafusion_expr::utils::from_plan;
 use datafusion_expr::{
     aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown,
     is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator,
-    WindowFrame, WindowFrameBound, WindowFrameUnits,
+    Projection, WindowFrame, WindowFrameBound, WindowFrameUnits,
 };
 use datafusion_expr::{ExprSchemable, Signature};
 
@@ -108,7 +108,14 @@ fn analyze_internal(
         })
         .collect::<Result<Vec<_>>>()?;
 
-    from_plan(plan, &new_expr, &new_inputs)
+    // TODO: use from_plan once https://github.com/apache/arrow-datafusion/issues/6613 is fixed
+    match &plan {
+        LogicalPlan::Projection(_) => Ok(LogicalPlan::Projection(Projection::try_new(
+            new_expr,
+            Arc::new(new_inputs[0].clone()),
+        )?)),
+        _ => from_plan(plan, &new_expr, &new_inputs),
+    }
 }
 
 pub(crate) struct TypeCoercionRewriter {