You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/06/28 08:29:16 UTC
[arrow-datafusion] branch main updated: Cleanup type coercion (#3419) (#6778)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b9ecfc517b Cleanup type coercion (#3419) (#6778)
b9ecfc517b is described below
commit b9ecfc517bbf5121af1c64f415568322aca1f290
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed Jun 28 09:29:10 2023 +0100
Cleanup type coercion (#3419) (#6778)
* Cleanup type coercion (#3419)
* Further fixes
* Tweak doc
* Review feedback
---
datafusion/core/src/physical_planner.rs | 19 +-
.../core/tests/sqllogictests/test_files/dates.slt | 2 +-
.../tests/sqllogictests/test_files/interval.slt | 34 +-
.../tests/sqllogictests/test_files/timestamps.slt | 2 +-
.../sqllogictests/test_files/type_coercion.slt | 8 +-
datafusion/expr/src/type_coercion/binary.rs | 523 ++++++++++-----------
datafusion/optimizer/src/analyzer/type_coercion.rs | 94 +---
datafusion/physical-expr/src/expressions/binary.rs | 67 +--
8 files changed, 339 insertions(+), 410 deletions(-)
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 75566208e3..da2f396e8e 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1956,7 +1956,7 @@ mod tests {
use fmt::Debug;
use std::collections::HashMap;
use std::convert::TryFrom;
- use std::ops::Not;
+ use std::ops::{BitAnd, Not};
use std::{any::Any, fmt};
fn make_session_state() -> SessionState {
@@ -2140,18 +2140,17 @@ mod tests {
async fn errors() -> Result<()> {
let bool_expr = col("c1").eq(col("c1"));
let cases = vec![
- // utf8 AND utf8
- col("c1").and(col("c1")),
+ // utf8 = utf8
+ col("c1").eq(col("c1")),
// u8 AND u8
- col("c3").and(col("c3")),
- // utf8 = bool
- col("c1").eq(bool_expr.clone()),
- // u32 AND bool
- col("c2").and(bool_expr),
+ col("c3").bitand(col("c3")),
+ // utf8 = u8
+ col("c1").eq(col("c3")),
+ // bool AND bool
+ bool_expr.clone().and(bool_expr),
];
for case in cases {
- let logical_plan = test_csv_scan().await?.project(vec![case.clone()]);
- assert!(logical_plan.is_ok());
+ test_csv_scan().await?.project(vec![case.clone()]).unwrap();
}
Ok(())
}
diff --git a/datafusion/core/tests/sqllogictests/test_files/dates.slt b/datafusion/core/tests/sqllogictests/test_files/dates.slt
index 5b76739e95..c35f16bc03 100644
--- a/datafusion/core/tests/sqllogictests/test_files/dates.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/dates.slt
@@ -85,7 +85,7 @@ g
h
## Plan error when compare Utf8 and timestamp in where clause
-statement error Error during planning: Timestamp\(Nanosecond, Some\("\+00:00"\)\) \+ Utf8 can't be evaluated because there isn't a common type to coerce the types to
+statement error DataFusion error: type_coercion\ncaused by\nError during planning: Cannot coerce arithmetic expression Timestamp\(Nanosecond, Some\("\+00:00"\)\) \+ Utf8 to valid types
select i_item_desc from test
where d3_date > now() + '5 days';
diff --git a/datafusion/core/tests/sqllogictests/test_files/interval.slt b/datafusion/core/tests/sqllogictests/test_files/interval.slt
index 889e2759b2..9dd56c4636 100644
--- a/datafusion/core/tests/sqllogictests/test_files/interval.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/interval.slt
@@ -430,13 +430,15 @@ select '1 month'::interval + '1980-01-01T12:00:00'::timestamp;
----
1980-02-01T12:00:00
-# Exected error: interval (scalar) - date / timestamp (scalar)
-
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to
+query D
select '1 month'::interval - '1980-01-01'::date;
+----
+1979-12-01
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to
+query P
select '1 month'::interval - '1980-01-01T12:00:00'::timestamp;
+----
+1979-12-01T12:00:00
# interval (array) + date / timestamp (array)
query D
@@ -454,11 +456,19 @@ select i + ts from t;
2000-02-01T00:01:00
# expected error interval (array) - date / timestamp (array)
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to
+query D
select i - d from t;
+----
+1979-12-01
+1990-09-30
+1980-01-02
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to
+query P
select i - ts from t;
+----
+1999-12-01T00:00:00
+1999-12-31T12:11:10
+2000-01-31T23:59:00
# interval (scalar) + date / timestamp (array)
@@ -477,11 +487,19 @@ select '1 month'::interval + ts from t;
2000-03-01T00:00:00
# expected error interval (scalar) - date / timestamp (array)
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to
+query D
select '1 month'::interval - d from t;
+----
+1979-12-01
+1990-09-01
+1979-12-02
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to
+query P
select '1 month'::interval - ts from t;
+----
+1999-12-01T00:00:00
+1999-12-01T12:11:10
+2000-01-01T00:00:00
# interval + date
query D
diff --git a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
index 3ba7c38f16..5250ce2399 100644
--- a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
@@ -1233,7 +1233,7 @@ SELECT ts1 + i FROM foo;
2003-07-12T01:31:15.000123463
# Timestamp + Timestamp => error
-query error DataFusion error: type_coercion\ncaused by\nInternal error: Unsupported operation Plus between Timestamp\(Nanosecond, None\) and Timestamp\(Nanosecond, None\)\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker
+query error DataFusion error: Arrow error: Cast error: Cannot perform arithmetic operation between array of type Timestamp\(Nanosecond, None\) and array of type Timestamp\(Nanosecond, None\)
SELECT ts1 + ts2
FROM foo;
diff --git a/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt b/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt
index 9aced0a3fd..8b329df0c1 100644
--- a/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt
@@ -43,9 +43,13 @@ SELECT '2023-05-01 12:30:00'::timestamp - interval '1 month';
2023-04-01T12:30:00
# interval - date
-query error DataFusion error: type_coercion
+query D
select interval '1 month' - '2023-05-01'::date;
+----
+2023-04-01
# interval - timestamp
-query error DataFusion error: type_coercion
+query P
SELECT interval '1 month' - '2023-05-01 12:30:00'::timestamp;
+----
+2023-04-01T12:30:00
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 7c9179b2f3..64ebf8b559 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -25,9 +25,144 @@ use arrow::datatypes::{
use datafusion_common::DataFusionError;
use datafusion_common::Result;
-use crate::type_coercion::{is_datetime, is_decimal, is_interval, is_numeric};
+use crate::type_coercion::is_numeric;
use crate::Operator;
+/// The type signature of an instantiation of binary expression
+struct Signature {
+ /// The type to coerce the left argument to
+ lhs: DataType,
+ /// The type to coerce the right argument to
+ rhs: DataType,
+ /// The return type of the expression
+ ret: DataType,
+}
+
+impl Signature {
+ /// A signature where the inputs are the same type as the output
+ fn uniform(t: DataType) -> Self {
+ Self {
+ lhs: t.clone(),
+ rhs: t.clone(),
+ ret: t,
+ }
+ }
+
+ /// A signature where the inputs are the same type with a boolean output
+ fn comparison(t: DataType) -> Self {
+ Self {
+ lhs: t.clone(),
+ rhs: t,
+ ret: DataType::Boolean,
+ }
+ }
+}
+
+/// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs`
+fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result<Signature> {
+ match op {
+ Operator::Eq |
+ Operator::NotEq |
+ Operator::Lt |
+ Operator::LtEq |
+ Operator::Gt |
+ Operator::GtEq |
+ Operator::IsDistinctFrom |
+ Operator::IsNotDistinctFrom => {
+ comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
+ DataFusionError::Plan(format!(
+ "Cannot infer common argument type for comparison operation {lhs} {op} {rhs}"
+ ))
+ })
+ }
+ Operator::And | Operator::Or => match (lhs, rhs) {
+ // logical binary boolean operators can only be evaluated in bools or nulls
+ (DataType::Boolean, DataType::Boolean)
+ | (DataType::Null, DataType::Null)
+ | (DataType::Boolean, DataType::Null)
+ | (DataType::Null, DataType::Boolean) => Ok(Signature::uniform(DataType::Boolean)),
+ _ => Err(DataFusionError::Plan(format!(
+ "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}"
+ ))),
+ },
+ Operator::RegexMatch |
+ Operator::RegexIMatch |
+ Operator::RegexNotMatch |
+ Operator::RegexNotIMatch => {
+ regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
+ DataFusionError::Plan(format!(
+ "Cannot infer common argument type for regex operation {lhs} {op} {rhs}"
+ ))
+ })
+ }
+ Operator::BitwiseAnd
+ | Operator::BitwiseOr
+ | Operator::BitwiseXor
+ | Operator::BitwiseShiftRight
+ | Operator::BitwiseShiftLeft => {
+ bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
+ DataFusionError::Plan(format!(
+ "Cannot infer common type for bitwise operation {lhs} {op} {rhs}"
+ ))
+ })
+ }
+ Operator::StringConcat => {
+ string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
+ DataFusionError::Plan(format!(
+ "Cannot infer common string type for string concat operation {lhs} {op} {rhs}"
+ ))
+ })
+ }
+ Operator::Plus |
+ Operator::Minus |
+ Operator::Multiply |
+ Operator::Divide|
+ Operator::Modulo => {
+ // TODO: this logic would be easier to follow if the functions were inlined
+ if let Some(ret) = mathematics_temporal_result_type(lhs, rhs) {
+ // Temporal arithmetic, e.g. Date32 + Interval
+ Ok(Signature{
+ lhs: lhs.clone(),
+ rhs: rhs.clone(),
+ ret,
+ })
+ } else if let Some(coerced) = temporal_coercion(lhs, rhs) {
+ // Temporal arithmetic by first coercing to a common time representation
+ // e.g. Date32 - Timestamp
+ let ret = mathematics_temporal_result_type(&coerced, &coerced).ok_or_else(|| {
+ DataFusionError::Plan(format!(
+ "Cannot get result type for temporal operation {coerced} {op} {coerced}"
+ ))
+ })?;
+ Ok(Signature{
+ lhs: coerced.clone(),
+ rhs: coerced,
+ ret,
+ })
+ } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) {
+ // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0)
+ let ret = decimal_op_mathematics_type(op, &lhs, &rhs).ok_or_else(|| {
+ DataFusionError::Plan(format!(
+ "Cannot get result type for decimal operation {lhs} {op} {rhs}"
+ ))
+ })?;
+ Ok(Signature{
+ lhs,
+ rhs,
+ ret,
+ })
+ } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) {
+ // Numeric arithmetic, e.g. Int32 + Int32
+ Ok(Signature::uniform(numeric))
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "Cannot coerce arithmetic expression {lhs} {op} {rhs} to valid types"
+ )))
+ }
+ }
+ }
+}
+
/// Returns the result type of applying mathematics operations such as
/// `+` to arguments of `lhs_type` and `rhs_type`.
fn mathematics_temporal_result_type(
@@ -38,14 +173,6 @@ fn mathematics_temporal_result_type(
use arrow::datatypes::IntervalUnit::*;
use arrow::datatypes::TimeUnit::*;
- if !is_interval(lhs_type)
- && !is_interval(rhs_type)
- && !is_datetime(lhs_type)
- && !is_datetime(rhs_type)
- {
- return None;
- };
-
match (lhs_type, rhs_type) {
// datetime +/- interval
(Interval(_), Timestamp(_, _)) => Some(rhs_type.clone()),
@@ -66,194 +193,68 @@ fn mathematics_temporal_result_type(
| (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
Some(Interval(MonthDayNano))
}
- (Timestamp(_, _), Timestamp(_, _)) => None,
// date - date
(Date32, Date32) => Some(Interval(DayTime)),
(Date64, Date64) => Some(Interval(MonthDayNano)),
- (Date32, Date64) | (Date64, Date32) => Some(Interval(MonthDayNano)),
- // date - timestamp, timestamp - date
- (Date32, Timestamp(_, _))
- | (Timestamp(_, _), Date32)
- | (Date64, Timestamp(_, _))
- | (Timestamp(_, _), Date64) => {
- // TODO: make get_result_type must after coerce type.
- // if type isn't coerced, we need get common type, and then get result type.
- let common_type = temporal_coercion(lhs_type, rhs_type);
- common_type.and_then(|t| mathematics_temporal_result_type(&t, &t))
- }
_ => None,
}
}
/// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types
pub fn get_result_type(
- lhs_type: &DataType,
+ lhs: &DataType,
op: &Operator,
- rhs_type: &DataType,
+ rhs: &DataType,
) -> Result<DataType> {
- if op.is_numerical_operators() && any_decimal(lhs_type, rhs_type) {
- let (coerced_lhs_type, coerced_rhs_type) =
- math_decimal_coercion(lhs_type, rhs_type);
-
- let lhs_type = coerced_lhs_type.unwrap_or(lhs_type.clone());
- let rhs_type = coerced_rhs_type.unwrap_or(rhs_type.clone());
-
- if op.is_numerical_operators() {
- if let Some(result_type) =
- decimal_op_mathematics_type(op, &lhs_type, &rhs_type)
- {
- return Ok(result_type);
- }
- }
- }
- let result = match op {
- Operator::And
- | Operator::Or
- | Operator::Eq
- | Operator::NotEq
- | Operator::Lt
- | Operator::Gt
- | Operator::GtEq
- | Operator::LtEq
- | Operator::RegexMatch
- | Operator::RegexIMatch
- | Operator::RegexNotMatch
- | Operator::RegexNotIMatch
- | Operator::IsDistinctFrom
- | Operator::IsNotDistinctFrom => Some(DataType::Boolean),
- Operator::Plus | Operator::Minus
- if is_datetime(lhs_type) && is_datetime(rhs_type)
- || (is_interval(lhs_type) && is_interval(rhs_type))
- || (is_datetime(lhs_type) && is_interval(rhs_type))
- || (is_interval(lhs_type) && is_datetime(rhs_type)) =>
- {
- mathematics_temporal_result_type(lhs_type, rhs_type)
- }
- // following same with `coerce_types`
- Operator::BitwiseAnd
- | Operator::BitwiseOr
- | Operator::BitwiseXor
- | Operator::BitwiseShiftRight
- | Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type),
- Operator::Plus
- | Operator::Minus
- | Operator::Modulo
- | Operator::Divide
- | Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type),
- Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type),
- };
-
- result.ok_or(DataFusionError::Plan(format!(
- "Unsupported argument types. Can not evaluate {lhs_type:?} {op} {rhs_type:?}"
- )))
+ signature(lhs, op, rhs).map(|sig| sig.ret)
}
-/// Coercion rules for all binary operators. Returns the 'coerce_types'
-/// is returns the type the arguments should be coerced to
-///
-/// Returns None if no suitable type can be found.
-pub fn coerce_types(
- lhs_type: &DataType,
+/// Returns the coerced input types for a binary expression evaluating the `op` with the left and right hand types
+pub fn get_input_types(
+ lhs: &DataType,
op: &Operator,
- rhs_type: &DataType,
-) -> Result<DataType> {
- // This result MUST be compatible with `binary_coerce`
- let result = match op {
- Operator::BitwiseAnd
- | Operator::BitwiseOr
- | Operator::BitwiseXor
- | Operator::BitwiseShiftRight
- | Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type),
- Operator::And | Operator::Or => match (lhs_type, rhs_type) {
- // logical binary boolean operators can only be evaluated in bools or nulls
- (DataType::Boolean, DataType::Boolean)
- | (DataType::Null, DataType::Null)
- | (DataType::Boolean, DataType::Null)
- | (DataType::Null, DataType::Boolean) => Some(DataType::Boolean),
- _ => None,
- },
- // logical comparison operators have their own rules, and always return a boolean
- Operator::Eq
- | Operator::NotEq
- | Operator::Lt
- | Operator::Gt
- | Operator::GtEq
- | Operator::LtEq
- | Operator::IsDistinctFrom
- | Operator::IsNotDistinctFrom => comparison_coercion(lhs_type, rhs_type),
- Operator::Plus | Operator::Minus
- if is_interval(lhs_type) && is_interval(rhs_type) =>
- {
- temporal_coercion(lhs_type, rhs_type)
- }
- Operator::Minus if is_datetime(lhs_type) && is_datetime(rhs_type) => {
- temporal_coercion(lhs_type, rhs_type)
- }
- // for math expressions, the final value of the coercion is also the return type
- // because coercion favours higher information types
- Operator::Plus
- | Operator::Minus
- | Operator::Modulo
- | Operator::Divide
- | Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type),
- Operator::RegexMatch
- | Operator::RegexIMatch
- | Operator::RegexNotMatch
- | Operator::RegexNotIMatch => regex_coercion(lhs_type, rhs_type),
- // "||" operator has its own rules, and always return a string type
- Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type),
- };
-
- // re-write the error message of failed coercions to include the operator's information
- result.ok_or(DataFusionError::Plan(format!("{lhs_type:?} {op} {rhs_type:?} can't be evaluated because there isn't a common type to coerce the types to")))
+ rhs: &DataType,
+) -> Result<(DataType, DataType)> {
+ signature(lhs, op, rhs).map(|sig| (sig.lhs, sig.rhs))
}
/// Coercion rules for mathematics operators between decimal and non-decimal types.
-pub fn math_decimal_coercion(
+fn math_decimal_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
-) -> (Option<DataType>, Option<DataType>) {
+) -> Option<(DataType, DataType)> {
use arrow::datatypes::DataType::*;
- if both_decimal(lhs_type, rhs_type) {
- return (None, None);
- }
-
match (lhs_type, rhs_type) {
- (Null, dec_type @ Decimal128(_, _)) => (Some(dec_type.clone()), None),
- (dec_type @ Decimal128(_, _), Null) => (None, Some(dec_type.clone())),
(Dictionary(key_type, value_type), _) => {
- let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type);
- let lhs_type = value_type
- .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)));
- (lhs_type, rhs_type)
+ let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?;
+ Some((Dictionary(key_type.clone(), Box::new(value_type)), rhs_type))
}
(_, Dictionary(key_type, value_type)) => {
- let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type);
- let rhs_type = value_type
- .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)));
- (lhs_type, rhs_type)
+ let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?;
+ Some((lhs_type, Dictionary(key_type.clone(), Box::new(value_type))))
}
- (Decimal128(_, _), Float32 | Float64) => (Some(Float64), Some(Float64)),
- (Float32 | Float64, Decimal128(_, _)) => (Some(Float64), Some(Float64)),
- (Decimal128(_, _), _) => {
- let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type);
- (None, converted_decimal_type)
+ (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
+ Some((dec_type.clone(), dec_type.clone()))
}
- (_, Decimal128(_, _)) => {
- let converted_decimal_type = coerce_numeric_type_to_decimal(lhs_type);
- (converted_decimal_type, None)
+ (Decimal128(_, _), Decimal128(_, _)) => {
+ Some((lhs_type.clone(), rhs_type.clone()))
}
- _ => (None, None),
+ // Unlike with comparison we don't coerce to a decimal in the case of floating point
+ // numbers, instead falling back to floating point arithmetic instead
+ (Decimal128(_, _), Int8 | Int16 | Int32 | Int64) => {
+ Some((lhs_type.clone(), coerce_numeric_type_to_decimal(rhs_type)?))
+ }
+ (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => {
+ Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone()))
+ }
+ _ => None,
}
}
/// Returns the output type of applying bitwise operations such as
/// `&`, `|`, or `xor`to arguments of `lhs_type` and `rhs_type`.
-pub(crate) fn bitwise_coercion(
- left_type: &DataType,
- right_type: &DataType,
-) -> Option<DataType> {
+fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
if !both_numeric_or_null_and_numeric(left_type, right_type) {
@@ -289,9 +290,7 @@ pub(crate) fn bitwise_coercion(
}
}
-/// Returns the output type of applying comparison operations such as
-/// `eq`, `not eq`, `lt`, `lteq`, `gt`, and `gteq` to arguments
-/// of `lhs_type` and `rhs_type`.
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
if lhs_type == rhs_type {
// same type => equality is possible
@@ -303,11 +302,11 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
+ .or_else(|| string_temporal_coercion(lhs_type, rhs_type))
}
-/// Returns the output type of applying numeric operations such as `=`
-/// to arguments `lhs_type` and `rhs_type` if one is numeric and one
-/// is `Utf8`/`LargeUtf8`.
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
+/// where one is numeric and one is `Utf8`/`LargeUtf8`.
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
@@ -319,8 +318,48 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
}
}
-/// Returns the output type of applying numeric operations such as `=`
-/// to arguments `lhs_type` and `rhs_type` if both are numeric
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
+/// where one is temporal and one is `Utf8`/`LargeUtf8`.
+///
+/// Note this cannot be performed in case of arithmetic as there is insufficient information
+/// to correctly determine the type of argument. Consider
+///
+/// ```sql
+/// timestamp > now() - '1 month'
+/// interval > now() - '1970-01-2021'
+/// ```
+///
+/// In the absence of a full type inference system, we can't determine the correct type
+/// to parse the string argument
+fn string_temporal_coercion(
+ lhs_type: &DataType,
+ rhs_type: &DataType,
+) -> Option<DataType> {
+ use arrow::datatypes::DataType::*;
+ match (lhs_type, rhs_type) {
+ (Utf8, Date32) | (Date32, Utf8) => Some(Date32),
+ (Utf8, Date64) | (Date64, Utf8) => Some(Date64),
+ (Utf8, Time32(unit)) | (Time32(unit), Utf8) => {
+ match is_time_with_valid_unit(Time32(unit.clone())) {
+ false => None,
+ true => Some(Time32(unit.clone())),
+ }
+ }
+ (Utf8, Time64(unit)) | (Time64(unit), Utf8) => {
+ match is_time_with_valid_unit(Time64(unit.clone())) {
+ false => None,
+ true => Some(Time64(unit.clone())),
+ }
+ }
+ (Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => {
+ Some(Timestamp(TimeUnit::Nanosecond, tz.clone()))
+ }
+ _ => None,
+ }
+}
+
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
+/// where one both are numeric
fn comparison_binary_numeric_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
@@ -338,7 +377,7 @@ fn comparison_binary_numeric_coercion(
// these are ordered from most informative to least informative so
// that the coercion does not lose information via truncation
match (lhs_type, rhs_type) {
- // support decimal data type for comparison operation
+ // Prefer decimal data type over floating point for comparison operation
(Decimal128(_, _), Decimal128(_, _)) => {
get_wider_decimal_type(lhs_type, rhs_type)
}
@@ -381,26 +420,14 @@ fn comparison_binary_numeric_coercion(
}
}
-/// Returns the output type of applying numeric operations such as `=`
-/// to a decimal type `decimal_type` and `other_type`
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of
+/// a comparison operation where one is a decimal
fn get_comparison_common_decimal_type(
decimal_type: &DataType,
other_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
- let other_decimal_type = &match other_type {
- // This conversion rule is from spark
- // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
- Int8 => Decimal128(3, 0),
- Int16 => Decimal128(5, 0),
- Int32 => Decimal128(10, 0),
- Int64 => Decimal128(20, 0),
- Float32 => Decimal128(14, 7),
- Float64 => Decimal128(30, 15),
- _ => {
- return None;
- }
- };
+ let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?;
match (decimal_type, &other_decimal_type) {
(d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2),
_ => None,
@@ -430,6 +457,8 @@ fn get_wider_decimal_type(
/// Now, we just support the signed integer type and floating-point type.
fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
+ // This conversion rule is from spark
+ // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
match numeric_type {
Int8 => Some(Decimal128(3, 0)),
Int16 => Some(Decimal128(5, 0)),
@@ -499,6 +528,7 @@ pub fn coercion_decimal_mathematics_type(
left_decimal_type: &DataType,
right_decimal_type: &DataType,
) -> Option<DataType> {
+ // TODO: Move this logic into kernel implementations
use arrow::datatypes::DataType::*;
match (left_decimal_type, right_decimal_type) {
// The promotion rule from spark
@@ -518,7 +548,7 @@ pub fn coercion_decimal_mathematics_type(
}
}
-/// Returns the output type of applying mathematics operations on decimal types.
+/// Returns the output type of applying mathematics operations on two decimal types.
/// The rule is from spark. Note that this is different to the coerced type applied
/// to two sides of the arithmetic operation.
pub fn decimal_op_mathematics_type(
@@ -605,29 +635,6 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) ->
}
}
-/// Determine if at least of one of lhs and rhs is decimal, and the other must be NULL or decimal
-fn both_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool {
- use arrow::datatypes::DataType::*;
- match (lhs_type, rhs_type) {
- (_, Null) => is_decimal(lhs_type),
- (Null, _) => is_decimal(rhs_type),
- (Decimal128(_, _), Decimal128(_, _)) => true,
- (Dictionary(_, value_type), _) => is_decimal(value_type) && is_decimal(rhs_type),
- (_, Dictionary(_, value_type)) => is_decimal(lhs_type) && is_decimal(value_type),
- _ => false,
- }
-}
-
-/// Determine if at least of one of lhs and rhs is decimal
-pub fn any_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool {
- use arrow::datatypes::DataType::*;
- match (lhs_type, rhs_type) {
- (Dictionary(_, value_type), _) => is_decimal(value_type) || is_decimal(rhs_type),
- (_, Dictionary(_, value_type)) => is_decimal(lhs_type) || is_decimal(value_type),
- (_, _) => is_decimal(lhs_type) || is_decimal(rhs_type),
- }
-}
-
/// Coercion rules for Dictionaries: the type that both lhs and rhs
/// can be casted to for the purpose of a computation.
///
@@ -743,30 +750,10 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTyp
use arrow::datatypes::IntervalUnit::*;
use arrow::datatypes::TimeUnit::*;
- if lhs_type == rhs_type {
- return Some(lhs_type.clone());
- }
match (lhs_type, rhs_type) {
// interval +/-
(Interval(_), Interval(_)) => Some(Interval(MonthDayNano)),
(Date64, Date32) | (Date32, Date64) => Some(Date64),
- (Utf8, Date32) | (Date32, Utf8) => Some(Date32),
- (Utf8, Date64) | (Date64, Utf8) => Some(Date64),
- (Utf8, Time32(unit)) | (Time32(unit), Utf8) => {
- match is_time_with_valid_unit(Time32(unit.clone())) {
- false => None,
- true => Some(Time32(unit.clone())),
- }
- }
- (Utf8, Time64(unit)) | (Time64(unit), Utf8) => {
- match is_time_with_valid_unit(Time64(unit.clone())) {
- false => None,
- true => Some(Time64(unit.clone())),
- }
- }
- (Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => {
- Some(Timestamp(Nanosecond, tz.clone()))
- }
(Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => {
Some(Timestamp(Nanosecond, None))
}
@@ -832,7 +819,6 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
mod tests {
use arrow::datatypes::DataType;
- use datafusion_common::assert_contains;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
@@ -843,10 +829,13 @@ mod tests {
#[test]
fn test_coercion_error() -> Result<()> {
let result_type =
- coerce_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8);
+ get_input_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8);
if let Err(DataFusionError::Plan(e)) = result_type {
- assert_eq!(e, "Float32 + Utf8 can't be evaluated because there isn't a common type to coerce the types to");
+ assert_eq!(
+ e,
+ "Cannot coerce arithmetic expression Float32 + Utf8 to valid types"
+ );
Ok(())
} else {
Err(DataFusionError::Internal(
@@ -891,12 +880,14 @@ mod tests {
for (i, input_type) in input_types.iter().enumerate() {
let expect_type = &result_types[i];
for op in comparison_op_types {
- let result_type = coerce_types(&input_decimal, &op, input_type)?;
- assert_eq!(expect_type, &result_type);
+ let (lhs, rhs) = get_input_types(&input_decimal, &op, input_type)?;
+ assert_eq!(expect_type, &lhs);
+ assert_eq!(expect_type, &rhs);
}
}
// negative test
- let result_type = coerce_types(&input_decimal, &Operator::Eq, &DataType::Boolean);
+ let result_type =
+ get_input_types(&input_decimal, &Operator::Eq, &DataType::Boolean);
assert!(result_type.is_err());
Ok(())
}
@@ -1017,24 +1008,27 @@ mod tests {
macro_rules! test_coercion_binary_rule {
($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{
- let result = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?;
- assert_eq!(result, $C_TYPE);
+ let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?;
+ assert_eq!(lhs, $C_TYPE);
+ assert_eq!(rhs, $C_TYPE);
}};
}
#[test]
fn test_date_timestamp_arithmetic_error() -> Result<()> {
- let common_type = coerce_types(
+ let (lhs, rhs) = get_input_types(
&DataType::Timestamp(TimeUnit::Nanosecond, None),
&Operator::Minus,
&DataType::Timestamp(TimeUnit::Millisecond, None),
)?;
- assert_eq!(common_type.to_string(), "Timestamp(Millisecond, None)");
+ assert_eq!(lhs.to_string(), "Timestamp(Millisecond, None)");
+ assert_eq!(rhs.to_string(), "Timestamp(Millisecond, None)");
- let err = coerce_types(&DataType::Date32, &Operator::Plus, &DataType::Date64)
- .unwrap_err()
- .to_string();
- assert_contains!(&err, "Date32 + Date64 can't be evaluated because there isn't a common type to coerce the types to");
+ let (lhs, rhs) =
+ get_input_types(&DataType::Date32, &Operator::Plus, &DataType::Date64)
+ .unwrap();
+ assert_eq!(lhs.to_string(), "Date64");
+ assert_eq!(rhs.to_string(), "Date64");
Ok(())
}
@@ -1234,18 +1228,15 @@ mod tests {
lhs_type: DataType,
rhs_type: DataType,
mathematics_op: Operator,
- expected_lhs_type: Option<DataType>,
- expected_rhs_type: Option<DataType>,
+ expected_lhs_type: DataType,
+ expected_rhs_type: DataType,
expected_coerced_type: Option<DataType>,
expected_output_type: DataType,
) {
// The coerced types for lhs and rhs, if any of them is not decimal
- let (l, r) = math_decimal_coercion(&lhs_type, &rhs_type);
- assert_eq!(l, expected_lhs_type);
- assert_eq!(r, expected_rhs_type);
-
- let lhs_type = l.unwrap_or(lhs_type);
- let rhs_type = r.unwrap_or(rhs_type);
+ let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap();
+ assert_eq!(lhs_type, expected_lhs_type);
+ assert_eq!(rhs_type, expected_rhs_type);
// The coerced type of decimal math expression, applied during expression evaluation
let coerced_type =
@@ -1264,8 +1255,8 @@ mod tests {
DataType::Decimal128(10, 2),
DataType::Decimal128(10, 2),
Operator::Plus,
- None,
- None,
+ DataType::Decimal128(10, 2),
+ DataType::Decimal128(10, 2),
Some(DataType::Decimal128(11, 2)),
DataType::Decimal128(11, 2),
);
@@ -1274,8 +1265,8 @@ mod tests {
DataType::Int32,
DataType::Decimal128(10, 2),
Operator::Plus,
- Some(DataType::Decimal128(10, 0)),
- None,
+ DataType::Decimal128(10, 0),
+ DataType::Decimal128(10, 2),
Some(DataType::Decimal128(13, 2)),
DataType::Decimal128(13, 2),
);
@@ -1284,8 +1275,8 @@ mod tests {
DataType::Int32,
DataType::Decimal128(10, 2),
Operator::Minus,
- Some(DataType::Decimal128(10, 0)),
- None,
+ DataType::Decimal128(10, 0),
+ DataType::Decimal128(10, 2),
Some(DataType::Decimal128(13, 2)),
DataType::Decimal128(13, 2),
);
@@ -1294,8 +1285,8 @@ mod tests {
DataType::Int32,
DataType::Decimal128(10, 2),
Operator::Multiply,
- Some(DataType::Decimal128(10, 0)),
- None,
+ DataType::Decimal128(10, 0),
+ DataType::Decimal128(10, 2),
None,
DataType::Decimal128(21, 2),
);
@@ -1304,8 +1295,8 @@ mod tests {
DataType::Int32,
DataType::Decimal128(10, 2),
Operator::Divide,
- Some(DataType::Decimal128(10, 0)),
- None,
+ DataType::Decimal128(10, 0),
+ DataType::Decimal128(10, 2),
Some(DataType::Decimal128(12, 2)),
DataType::Decimal128(23, 11),
);
@@ -1314,8 +1305,8 @@ mod tests {
DataType::Int32,
DataType::Decimal128(10, 2),
Operator::Modulo,
- Some(DataType::Decimal128(10, 0)),
- None,
+ DataType::Decimal128(10, 0),
+ DataType::Decimal128(10, 2),
Some(DataType::Decimal128(12, 2)),
DataType::Decimal128(10, 2),
);
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 412abbfae6..61153c0d36 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -32,7 +32,7 @@ use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
- any_decimal, coerce_types, comparison_coercion, like_coercion, math_decimal_coercion,
+ comparison_coercion, get_input_types, like_coercion,
};
use datafusion_expr::type_coercion::functions::data_types;
use datafusion_expr::type_coercion::other::{
@@ -230,72 +230,18 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
let expr = Expr::ILike(Like::new(negated, expr, pattern, escape_char));
Ok(expr)
}
- Expr::BinaryExpr(BinaryExpr {
- ref left,
- op,
- ref right,
- }) => {
- // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419
- let left_type = left.get_type(&self.schema)?;
- let right_type = right.get_type(&self.schema)?;
- match (&left_type, &right_type) {
- // Handle some case about Interval.
- (
- DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _),
- &DataType::Interval(_),
- ) if matches!(op, Operator::Plus | Operator::Minus) => Ok(expr),
- (
- &DataType::Interval(_),
- DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _),
- ) if matches!(op, Operator::Plus) => Ok(expr),
- (DataType::Timestamp(_, _), DataType::Timestamp(_, _))
- if op.is_numerical_operators() =>
- {
- if matches!(op, Operator::Minus) {
- Ok(expr)
- } else {
- Err(DataFusionError::Internal(format!(
- "Unsupported operation {op:?} between {left_type:?} and {right_type:?}"
- )))
- }
- }
- // For numerical operations between decimals, we don't coerce the types.
- // But if only one of the operands is decimal, we cast the other operand to decimal
- // if the other operand is integer. If the other operand is float, we cast the
- // decimal operand to float.
- (lhs_type, rhs_type)
- if op.is_numerical_operators()
- && any_decimal(lhs_type, rhs_type) =>
- {
- let (coerced_lhs_type, coerced_rhs_type) =
- math_decimal_coercion(lhs_type, rhs_type);
- let new_left = if let Some(lhs_type) = coerced_lhs_type {
- left.clone().cast_to(&lhs_type, &self.schema)?
- } else {
- left.as_ref().clone()
- };
- let new_right = if let Some(rhs_type) = coerced_rhs_type {
- right.clone().cast_to(&rhs_type, &self.schema)?
- } else {
- right.as_ref().clone()
- };
- let expr = Expr::BinaryExpr(BinaryExpr::new(
- Box::new(new_left),
- op,
- Box::new(new_right),
- ));
- Ok(expr)
- }
- _ => {
- let common_type = coerce_types(&left_type, &op, &right_type)?;
- let expr = Expr::BinaryExpr(BinaryExpr::new(
- Box::new(left.clone().cast_to(&common_type, &self.schema)?),
- op,
- Box::new(right.clone().cast_to(&common_type, &self.schema)?),
- ));
- Ok(expr)
- }
- }
+ Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+ let (left_type, right_type) = get_input_types(
+ &left.get_type(&self.schema)?,
+ &op,
+ &right.get_type(&self.schema)?,
+ )?;
+
+ Ok(Expr::BinaryExpr(BinaryExpr::new(
+ Box::new(left.cast_to(&left_type, &self.schema)?),
+ op,
+ Box::new(right.cast_to(&right_type, &self.schema)?),
+ )))
}
Expr::Between(Between {
expr,
@@ -566,7 +512,7 @@ fn coerce_window_frame(
// The above op will be rewrite to the binary op when creating the physical op.
fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr> {
let left_type = expr.get_type(schema)?;
- coerce_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
+ get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
expr.clone().cast_to(&DataType::Boolean, schema)
}
@@ -1108,9 +1054,9 @@ mod test {
let empty = empty_with_type(DataType::Int64);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
- let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "");
- assert!(err.is_err());
- assert!(err.unwrap_err().to_string().contains("Int64 IS DISTINCT FROM Boolean can't be evaluated because there isn't a common type to coerce the types to"));
+ let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "");
+ let err = ret.unwrap_err().to_string();
+ assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}");
// is not true
let expr = col("a").is_not_true();
@@ -1210,9 +1156,9 @@ mod test {
let empty = empty_with_type(DataType::Utf8);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
- let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected);
- assert!(err.is_err());
- assert!(err.unwrap_err().to_string().contains("Utf8 IS DISTINCT FROM Boolean can't be evaluated because there isn't a common type to coerce the types to"));
+ let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected);
+ let err = ret.unwrap_err().to_string();
+ assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}");
// is not unknown
let expr = col("a").is_not_unknown();
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index e3bbefbcd3..c764692da3 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -1343,7 +1343,7 @@ mod tests {
ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef,
};
use datafusion_common::{ColumnStatistics, Result, Statistics};
- use datafusion_expr::type_coercion::binary::{coerce_types, math_decimal_coercion};
+ use datafusion_expr::type_coercion::binary::get_input_types;
// Create a binary expression without coercion. Used here when we do not want to coerce the expressions
// to valid types. Usage can result in an execution (after plan) error.
@@ -1447,10 +1447,10 @@ mod tests {
]);
let a = $A_ARRAY::from($A_VEC);
let b = $B_ARRAY::from($B_VEC);
- let common_type = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?;
+ let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?;
- let left = try_cast(col("a", &schema)?, &schema, common_type.clone())?;
- let right = try_cast(col("b", &schema)?, &schema, common_type)?;
+ let left = try_cast(col("a", &schema)?, &schema, lhs)?;
+ let right = try_cast(col("b", &schema)?, &schema, rhs)?;
// verify that we can construct the expression
let expression = binary(left, $OP, right, &schema)?;
@@ -2964,10 +2964,10 @@ mod tests {
) -> Result<()> {
let left_type = left.data_type();
let right_type = right.data_type();
- let common_type = coerce_types(left_type, &op, right_type)?;
+ let (lhs, rhs) = get_input_types(left_type, &op, right_type)?;
- let left_expr = try_cast(col("a", schema)?, schema, common_type.clone())?;
- let right_expr = try_cast(col("b", schema)?, schema, common_type)?;
+ let left_expr = try_cast(col("a", schema)?, schema, lhs)?;
+ let right_expr = try_cast(col("b", schema)?, schema, rhs)?;
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
@@ -2986,17 +2986,10 @@ mod tests {
expected: &BooleanArray,
) -> Result<()> {
let scalar = lit(scalar.clone());
- let op_type = coerce_types(&scalar.data_type(schema)?, &op, arr.data_type())?;
- let left_expr = if op_type.eq(&scalar.data_type(schema)?) {
- scalar
- } else {
- try_cast(scalar, schema, op_type.clone())?
- };
- let right_expr = if op_type.eq(arr.data_type()) {
- col("a", schema)?
- } else {
- try_cast(col("a", schema)?, schema, op_type)?
- };
+ let (lhs, rhs) =
+ get_input_types(&scalar.data_type(schema)?, &op, arr.data_type())?;
+ let left_expr = try_cast(scalar, schema, lhs)?;
+ let right_expr = try_cast(col("a", schema)?, schema, rhs)?;
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
@@ -3015,17 +3008,10 @@ mod tests {
expected: &BooleanArray,
) -> Result<()> {
let scalar = lit(scalar.clone());
- let op_type = coerce_types(arr.data_type(), &op, &scalar.data_type(schema)?)?;
- let right_expr = if op_type.eq(&scalar.data_type(schema)?) {
- scalar
- } else {
- try_cast(scalar, schema, op_type.clone())?
- };
- let left_expr = if op_type.eq(arr.data_type()) {
- col("a", schema)?
- } else {
- try_cast(col("a", schema)?, schema, op_type)?
- };
+ let (lhs, rhs) =
+ get_input_types(arr.data_type(), &op, &scalar.data_type(schema)?)?;
+ let left_expr = try_cast(col("a", schema)?, schema, lhs)?;
+ let right_expr = try_cast(scalar, schema, rhs)?;
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
@@ -4077,26 +4063,11 @@ mod tests {
op: Operator,
expected: ArrayRef,
) -> Result<()> {
- let (lhs_op_type, rhs_op_type) =
- math_decimal_coercion(left.data_type(), right.data_type());
+ let (lhs_type, rhs_type) =
+ get_input_types(left.data_type(), &op, right.data_type()).unwrap();
- let (left_expr, lhs_type) = if let Some(lhs_op_type) = lhs_op_type {
- (
- try_cast(col("a", schema)?, schema, lhs_op_type.clone())?,
- lhs_op_type,
- )
- } else {
- (col("a", schema)?, left.data_type().clone())
- };
-
- let (right_expr, rhs_type) = if let Some(rhs_op_type) = rhs_op_type {
- (
- try_cast(col("b", schema)?, schema, rhs_op_type.clone())?,
- rhs_op_type,
- )
- } else {
- (col("b", schema)?, right.data_type().clone())
- };
+ let left_expr = try_cast(col("a", schema)?, schema, lhs_type.clone())?;
+ let right_expr = try_cast(col("b", schema)?, schema, rhs_type.clone())?;
let coerced_schema = Schema::new(vec![
Field::new(