You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ja...@apache.org on 2023/04/05 02:41:40 UTC

[arrow-datafusion] branch main updated: fix: coerce type for InSubquery and fix timestamp minus timestamp. (#5853)

This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e86c83eb07 fix: coerce type for InSubquery and fix timestamp minus timestamp. (#5853)
e86c83eb07 is described below

commit e86c83eb077867805be987f5f7f6318d37af5e76
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Apr 5 10:41:34 2023 +0800

    fix: coerce type for InSubquery and fix timestamp minus timestamp. (#5853)
    
    * fix: type_coercion
    
    * fix: coerce type for InSubquery and fix timestamp minus timestamp.
---
 benchmarks/src/bin/parquet.rs               |  2 +-
 datafusion/expr/src/operator.rs             | 15 ++++++
 datafusion/expr/src/type_coercion/binary.rs |  3 +-
 datafusion/expr/src/utils.rs                |  4 +-
 datafusion/optimizer/src/type_coercion.rs   | 82 ++++++++++++++++++++++++++++-
 5 files changed, 99 insertions(+), 7 deletions(-)

diff --git a/benchmarks/src/bin/parquet.rs b/benchmarks/src/bin/parquet.rs
index 658d924dfb..7ddc18c07a 100644
--- a/benchmarks/src/bin/parquet.rs
+++ b/benchmarks/src/bin/parquet.rs
@@ -288,7 +288,7 @@ async fn run_filter_benchmarks(opt: Opt, test_file: &TestParquetFile) -> Result<
                 let (rows, elapsed) =
                     exec_scan(&ctx, test_file, filter_expr.clone(), opt.debug).await?;
                 let ms = elapsed.as_secs_f64() * 1000.0;
-                println!("Iteration {} returned {} rows in {ms} ms", i, rows);
+                println!("Iteration {i} returned {rows} rows in {ms} ms");
                 rundata.write_iter(elapsed, rows);
             }
         }
diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs
index b0437eabdd..659e0d1af3 100644
--- a/datafusion/expr/src/operator.rs
+++ b/datafusion/expr/src/operator.rs
@@ -110,6 +110,21 @@ impl Operator {
         }
     }
 
+    /// Return true if the operator is a numerical operator.
+    ///
+    /// For example, 'Binary(a, +, b)' would be a numerical expression.
+    /// PostgreSQL concept: https://www.postgresql.org/docs/7.0/operators2198.htm
+    pub fn is_numerical_operators(&self) -> bool {
+        matches!(
+            self,
+            Operator::Plus
+                | Operator::Minus
+                | Operator::Multiply
+                | Operator::Divide
+                | Operator::Modulo
+        )
+    }
+
     /// Return true if the operator is a comparison operator.
     ///
     /// For example, 'Binary(a, >, b)' would be a comparison expression.
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 9c6eefe9e7..99a5efb4fd 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -235,8 +235,7 @@ pub fn temporal_add_sub_coercion(
         // if two date/timestamp are being added/subtracted, return an error indicating that the operation is not supported
         (lhs, rhs, _) if (is_date(lhs) || is_timestamp(lhs)) && (is_date(rhs) || is_timestamp(rhs)) => {
             Err(DataFusionError::Plan(format!(
-                "{:?} {:?} is an unsupported operation. addition/subtraction on dates/timestamps only supported with interval types",
-                lhs_type, rhs_type
+                "{lhs_type:?} {rhs_type:?} is an unsupported operation. addition/subtraction on dates/timestamps only supported with interval types"
             )))
         }
         // return None if no coercion is possible
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index f3975e8346..a61e79fc24 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -112,7 +112,7 @@ fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
 fn check_grouping_set_size_limit(size: usize) -> Result<()> {
     let max_grouping_set_size = 65535;
     if size > max_grouping_set_size {
-        return Err(DataFusionError::Plan(format!("The number of group_expression in grouping_set exceeds the maximum limit {}, found {}", max_grouping_set_size, size)));
+        return Err(DataFusionError::Plan(format!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}")));
     }
 
     Ok(())
@@ -122,7 +122,7 @@ fn check_grouping_set_size_limit(size: usize) -> Result<()> {
 fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
     let max_grouping_sets_size = 4096;
     if size > max_grouping_sets_size {
-        return Err(DataFusionError::Plan(format!("The number of grouping_set in grouping_sets exceeds the maximum limit {}, found {}", max_grouping_sets_size, size)));
+        return Err(DataFusionError::Plan(format!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}")));
     }
 
     Ok(())
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index 0c24c5b877..a931da4b3e 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -150,7 +150,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 subquery,
                 negated,
             } => {
+                let expr_type = expr.get_type(&self.schema)?;
                 let new_plan = optimize_internal(&self.schema, &subquery.subquery)?;
+                let subquery_type = new_plan.schema().field(0).data_type();
+                let expr = if &expr_type == subquery_type {
+                    expr
+                } else {
+                    Box::new(expr.cast_to(subquery_type, &self.schema)?)
+                };
                 Ok(Expr::InSubquery {
                     expr,
                     subquery: Subquery {
@@ -250,6 +257,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                         // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419
                         Ok(expr.clone())
                     }
+                    (DataType::Timestamp(_, _), DataType::Timestamp(_, _))
+                        if op.is_numerical_operators() =>
+                    {
+                        if matches!(op, Operator::Minus) {
+                            Ok(expr)
+                        } else {
+                            Err(DataFusionError::Internal(format!(
+                                "Unsupported operation {op:?} between {left_type:?} and {right_type:?}"
+                            )))
+                        }
+                    }
                     _ => {
                         let coerced_type = coerce_types(&left_type, &op, &right_type)?;
                         let expr = Expr::BinaryExpr(BinaryExpr::new(
@@ -718,8 +736,8 @@ mod test {
     use datafusion_expr::{
         cast, col, concat, concat_ws, create_udaf, is_true,
         AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, BinaryExpr,
-        BuiltinScalarFunction, Case, ColumnarValue, ExprSchemable, Operator,
-        StateTypeFunction,
+        BuiltinScalarFunction, Case, ColumnarValue, ExprSchemable, Filter, Operator,
+        StateTypeFunction, Subquery,
     };
     use datafusion_expr::{
         lit,
@@ -1411,4 +1429,64 @@ mod test {
         assert_optimized_plan_eq(&plan, expected)?;
         Ok(())
     }
+
+    #[test]
+    fn timestamp_subtract_timestamp() -> Result<()> {
+        let expr = Expr::BinaryExpr(BinaryExpr::new(
+            Box::new(cast(
+                lit("1998-03-18"),
+                DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None),
+            )),
+            Operator::Minus,
+            Box::new(cast(
+                lit("1998-03-18"),
+                DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None),
+            )),
+        ));
+        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::empty()),
+        }));
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
+        dbg!(&plan);
+        let expected =
+            "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n  EmptyRelation";
+        assert_optimized_plan_eq(&plan, expected)?;
+        Ok(())
+    }
+
+    #[test]
+    fn in_subquery() -> Result<()> {
+        let empty_inside = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::new_with_metadata(
+                vec![DFField::new_unqualified("a_int32", DataType::Int32, true)],
+                std::collections::HashMap::new(),
+            )?),
+        }));
+        let empty_outside = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::new_with_metadata(
+                vec![DFField::new_unqualified("a_int64", DataType::Int64, true)],
+                std::collections::HashMap::new(),
+            )?),
+        }));
+        let in_subquery_expr = Expr::InSubquery {
+            expr: Box::new(col("a_int64")),
+            subquery: Subquery {
+                subquery: empty_inside,
+                outer_ref_columns: vec![],
+            },
+            negated: false,
+        };
+        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
+        // expect a cast on the left-hand expression (a_int64) to the subquery column type (Int32)
+        let expected = "\
+        Filter: CAST(a_int64 AS Int32) IN (<subquery>)\
+        \n  Subquery:\
+        \n    EmptyRelation\
+        \n  EmptyRelation";
+        assert_optimized_plan_eq(&plan, expected)?;
+        Ok(())
+    }
 }