You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/06/28 08:29:16 UTC

[arrow-datafusion] branch main updated: Cleanup type coercion (#3419) (#6778)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b9ecfc517b Cleanup type coercion (#3419) (#6778)
b9ecfc517b is described below

commit b9ecfc517bbf5121af1c64f415568322aca1f290
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed Jun 28 09:29:10 2023 +0100

    Cleanup type coercion (#3419) (#6778)
    
    * Cleanup type coercion (#3419)
    
    * Further fixes
    
    * Tweak doc
    
    * Review feedback
---
 datafusion/core/src/physical_planner.rs            |  19 +-
 .../core/tests/sqllogictests/test_files/dates.slt  |   2 +-
 .../tests/sqllogictests/test_files/interval.slt    |  34 +-
 .../tests/sqllogictests/test_files/timestamps.slt  |   2 +-
 .../sqllogictests/test_files/type_coercion.slt     |   8 +-
 datafusion/expr/src/type_coercion/binary.rs        | 523 ++++++++++-----------
 datafusion/optimizer/src/analyzer/type_coercion.rs |  94 +---
 datafusion/physical-expr/src/expressions/binary.rs |  67 +--
 8 files changed, 339 insertions(+), 410 deletions(-)

diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 75566208e3..da2f396e8e 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1956,7 +1956,7 @@ mod tests {
     use fmt::Debug;
     use std::collections::HashMap;
     use std::convert::TryFrom;
-    use std::ops::Not;
+    use std::ops::{BitAnd, Not};
     use std::{any::Any, fmt};
 
     fn make_session_state() -> SessionState {
@@ -2140,18 +2140,17 @@ mod tests {
     async fn errors() -> Result<()> {
         let bool_expr = col("c1").eq(col("c1"));
         let cases = vec![
-            // utf8 AND utf8
-            col("c1").and(col("c1")),
+            // utf8 = utf8
+            col("c1").eq(col("c1")),
             // u8 AND u8
-            col("c3").and(col("c3")),
-            // utf8 = bool
-            col("c1").eq(bool_expr.clone()),
-            // u32 AND bool
-            col("c2").and(bool_expr),
+            col("c3").bitand(col("c3")),
+            // utf8 = u8
+            col("c1").eq(col("c3")),
+            // bool AND bool
+            bool_expr.clone().and(bool_expr),
         ];
         for case in cases {
-            let logical_plan = test_csv_scan().await?.project(vec![case.clone()]);
-            assert!(logical_plan.is_ok());
+            test_csv_scan().await?.project(vec![case.clone()]).unwrap();
         }
         Ok(())
     }
diff --git a/datafusion/core/tests/sqllogictests/test_files/dates.slt b/datafusion/core/tests/sqllogictests/test_files/dates.slt
index 5b76739e95..c35f16bc03 100644
--- a/datafusion/core/tests/sqllogictests/test_files/dates.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/dates.slt
@@ -85,7 +85,7 @@ g
 h
 
 ## Plan error when compare Utf8 and timestamp in where clause
-statement error Error during planning: Timestamp\(Nanosecond, Some\("\+00:00"\)\) \+ Utf8 can't be evaluated because there isn't a common type to coerce the types to
+statement error DataFusion error: type_coercion\ncaused by\nError during planning: Cannot coerce arithmetic expression Timestamp\(Nanosecond, Some\("\+00:00"\)\) \+ Utf8 to valid types
 select i_item_desc from test
 where d3_date > now() + '5 days';
 
diff --git a/datafusion/core/tests/sqllogictests/test_files/interval.slt b/datafusion/core/tests/sqllogictests/test_files/interval.slt
index 889e2759b2..9dd56c4636 100644
--- a/datafusion/core/tests/sqllogictests/test_files/interval.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/interval.slt
@@ -430,13 +430,15 @@ select '1 month'::interval + '1980-01-01T12:00:00'::timestamp;
 ----
 1980-02-01T12:00:00
 
-# Exected error: interval (scalar) - date / timestamp (scalar)
-
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to
+query D
 select '1 month'::interval - '1980-01-01'::date;
+----
+1979-12-01
 
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to
+query P
 select '1 month'::interval - '1980-01-01T12:00:00'::timestamp;
+----
+1979-12-01T12:00:00
 
 # interval (array) + date / timestamp (array)
 query D
@@ -454,11 +456,19 @@ select i + ts from t;
 2000-02-01T00:01:00
 
 # expected error interval (array) - date / timestamp (array)
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to
+query D
 select i - d from t;
+----
+1979-12-01
+1990-09-30
+1980-01-02
 
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to
+query P
 select i - ts from t;
+----
+1999-12-01T00:00:00
+1999-12-31T12:11:10
+2000-01-31T23:59:00
 
 
 # interval (scalar) + date / timestamp (array)
@@ -477,11 +487,19 @@ select '1 month'::interval + ts from t;
 2000-03-01T00:00:00
 
 # expected error interval (scalar) - date / timestamp (array)
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to
+query D
 select '1 month'::interval - d from t;
+----
+1979-12-01
+1990-09-01
+1979-12-02
 
-query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to
+query P
 select '1 month'::interval - ts from t;
+----
+1999-12-01T00:00:00
+1999-12-01T12:11:10
+2000-01-01T00:00:00
 
 # interval + date
 query D
diff --git a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
index 3ba7c38f16..5250ce2399 100644
--- a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
@@ -1233,7 +1233,7 @@ SELECT ts1 + i FROM foo;
 2003-07-12T01:31:15.000123463
 
 # Timestamp + Timestamp => error
-query error DataFusion error: type_coercion\ncaused by\nInternal error: Unsupported operation Plus between Timestamp\(Nanosecond, None\) and Timestamp\(Nanosecond, None\)\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker
+query error DataFusion error: Arrow error: Cast error: Cannot perform arithmetic operation between array of type Timestamp\(Nanosecond, None\) and array of type Timestamp\(Nanosecond, None\)
 SELECT ts1 + ts2
 FROM foo;
 
diff --git a/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt b/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt
index 9aced0a3fd..8b329df0c1 100644
--- a/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt
@@ -43,9 +43,13 @@ SELECT '2023-05-01 12:30:00'::timestamp - interval '1 month';
 2023-04-01T12:30:00
 
 # interval - date
-query error DataFusion error: type_coercion
+query D
 select interval '1 month' - '2023-05-01'::date;
+----
+2023-04-01
 
 # interval - timestamp
-query error DataFusion error: type_coercion
+query P
 SELECT interval '1 month' - '2023-05-01 12:30:00'::timestamp;
+----
+2023-04-01T12:30:00
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 7c9179b2f3..64ebf8b559 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -25,9 +25,144 @@ use arrow::datatypes::{
 use datafusion_common::DataFusionError;
 use datafusion_common::Result;
 
-use crate::type_coercion::{is_datetime, is_decimal, is_interval, is_numeric};
+use crate::type_coercion::is_numeric;
 use crate::Operator;
 
+/// The type signature of an instantiation of a binary expression
+struct Signature {
+    /// The type to coerce the left argument to
+    lhs: DataType,
+    /// The type to coerce the right argument to
+    rhs: DataType,
+    /// The return type of the expression
+    ret: DataType,
+}
+
+impl Signature {
+    /// A signature where the inputs are the same type as the output
+    fn uniform(t: DataType) -> Self {
+        Self {
+            lhs: t.clone(),
+            rhs: t.clone(),
+            ret: t,
+        }
+    }
+
+    /// A signature where the inputs are the same type with a boolean output
+    fn comparison(t: DataType) -> Self {
+        Self {
+            lhs: t.clone(),
+            rhs: t,
+            ret: DataType::Boolean,
+        }
+    }
+}
+
+/// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs`
+fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result<Signature> {
+    match op {
+        Operator::Eq |
+        Operator::NotEq |
+        Operator::Lt |
+        Operator::LtEq |
+        Operator::Gt |
+        Operator::GtEq |
+        Operator::IsDistinctFrom |
+        Operator::IsNotDistinctFrom => {
+            comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
+                DataFusionError::Plan(format!(
+                    "Cannot infer common argument type for comparison operation {lhs} {op} {rhs}"
+                ))
+            })
+        }
+        Operator::And | Operator::Or => match (lhs, rhs) {
+            // logical binary boolean operators can only be evaluated in bools or nulls
+            (DataType::Boolean, DataType::Boolean)
+            | (DataType::Null, DataType::Null)
+            | (DataType::Boolean, DataType::Null)
+            | (DataType::Null, DataType::Boolean) => Ok(Signature::uniform(DataType::Boolean)),
+            _ => Err(DataFusionError::Plan(format!(
+                "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}"
+            ))),
+        },
+        Operator::RegexMatch |
+        Operator::RegexIMatch |
+        Operator::RegexNotMatch |
+        Operator::RegexNotIMatch => {
+            regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
+                DataFusionError::Plan(format!(
+                    "Cannot infer common argument type for regex operation {lhs} {op} {rhs}"
+                ))
+            })
+        }
+        Operator::BitwiseAnd
+        | Operator::BitwiseOr
+        | Operator::BitwiseXor
+        | Operator::BitwiseShiftRight
+        | Operator::BitwiseShiftLeft => {
+            bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
+                DataFusionError::Plan(format!(
+                    "Cannot infer common type for bitwise operation {lhs} {op} {rhs}"
+                ))
+            })
+        }
+        Operator::StringConcat => {
+            string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
+                DataFusionError::Plan(format!(
+                    "Cannot infer common string type for string concat operation {lhs} {op} {rhs}"
+                ))
+            })
+        }
+        Operator::Plus |
+        Operator::Minus |
+        Operator::Multiply |
+        Operator::Divide|
+        Operator::Modulo =>  {
+            // TODO: this logic would be easier to follow if the functions were inlined
+            if let Some(ret) = mathematics_temporal_result_type(lhs, rhs) {
+                // Temporal arithmetic, e.g. Date32 + Interval
+                Ok(Signature{
+                    lhs: lhs.clone(),
+                    rhs: rhs.clone(),
+                    ret,
+                })
+            } else if let Some(coerced) = temporal_coercion(lhs, rhs) {
+                // Temporal arithmetic by first coercing to a common time representation
+                // e.g. Date32 - Timestamp
+                let ret = mathematics_temporal_result_type(&coerced, &coerced).ok_or_else(|| {
+                    DataFusionError::Plan(format!(
+                        "Cannot get result type for temporal operation {coerced} {op} {coerced}"
+                    ))
+                })?;
+                Ok(Signature{
+                    lhs: coerced.clone(),
+                    rhs: coerced,
+                    ret,
+                })
+            } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) {
+                // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0)
+                let ret = decimal_op_mathematics_type(op, &lhs, &rhs).ok_or_else(|| {
+                    DataFusionError::Plan(format!(
+                        "Cannot get result type for decimal operation {lhs} {op} {rhs}"
+                    ))
+                })?;
+                Ok(Signature{
+                    lhs,
+                    rhs,
+                    ret,
+                })
+            } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) {
+                // Numeric arithmetic, e.g. Int32 + Int32
+                Ok(Signature::uniform(numeric))
+            } else {
+                Err(DataFusionError::Plan(format!(
+                    "Cannot coerce arithmetic expression {lhs} {op} {rhs} to valid types"
+                )))
+            }
+        }
+    }
+}
+
 /// Returns the result type of applying mathematics operations such as
 /// `+` to arguments of `lhs_type` and `rhs_type`.
 fn mathematics_temporal_result_type(
@@ -38,14 +173,6 @@ fn mathematics_temporal_result_type(
     use arrow::datatypes::IntervalUnit::*;
     use arrow::datatypes::TimeUnit::*;
 
-    if !is_interval(lhs_type)
-        && !is_interval(rhs_type)
-        && !is_datetime(lhs_type)
-        && !is_datetime(rhs_type)
-    {
-        return None;
-    };
-
     match (lhs_type, rhs_type) {
         // datetime +/- interval
         (Interval(_), Timestamp(_, _)) => Some(rhs_type.clone()),
@@ -66,194 +193,68 @@ fn mathematics_temporal_result_type(
         | (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
             Some(Interval(MonthDayNano))
         }
-        (Timestamp(_, _), Timestamp(_, _)) => None,
         // date - date
         (Date32, Date32) => Some(Interval(DayTime)),
         (Date64, Date64) => Some(Interval(MonthDayNano)),
-        (Date32, Date64) | (Date64, Date32) => Some(Interval(MonthDayNano)),
-        // date - timestamp, timestamp - date
-        (Date32, Timestamp(_, _))
-        | (Timestamp(_, _), Date32)
-        | (Date64, Timestamp(_, _))
-        | (Timestamp(_, _), Date64) => {
-            // TODO: make get_result_type must after coerce type.
-            // if type isn't coerced, we need get common type, and then get result type.
-            let common_type = temporal_coercion(lhs_type, rhs_type);
-            common_type.and_then(|t| mathematics_temporal_result_type(&t, &t))
-        }
         _ => None,
     }
 }
 
 /// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types
 pub fn get_result_type(
-    lhs_type: &DataType,
+    lhs: &DataType,
     op: &Operator,
-    rhs_type: &DataType,
+    rhs: &DataType,
 ) -> Result<DataType> {
-    if op.is_numerical_operators() && any_decimal(lhs_type, rhs_type) {
-        let (coerced_lhs_type, coerced_rhs_type) =
-            math_decimal_coercion(lhs_type, rhs_type);
-
-        let lhs_type = coerced_lhs_type.unwrap_or(lhs_type.clone());
-        let rhs_type = coerced_rhs_type.unwrap_or(rhs_type.clone());
-
-        if op.is_numerical_operators() {
-            if let Some(result_type) =
-                decimal_op_mathematics_type(op, &lhs_type, &rhs_type)
-            {
-                return Ok(result_type);
-            }
-        }
-    }
-    let result = match op {
-        Operator::And
-        | Operator::Or
-        | Operator::Eq
-        | Operator::NotEq
-        | Operator::Lt
-        | Operator::Gt
-        | Operator::GtEq
-        | Operator::LtEq
-        | Operator::RegexMatch
-        | Operator::RegexIMatch
-        | Operator::RegexNotMatch
-        | Operator::RegexNotIMatch
-        | Operator::IsDistinctFrom
-        | Operator::IsNotDistinctFrom => Some(DataType::Boolean),
-        Operator::Plus | Operator::Minus
-            if is_datetime(lhs_type) && is_datetime(rhs_type)
-                || (is_interval(lhs_type) && is_interval(rhs_type))
-                || (is_datetime(lhs_type) && is_interval(rhs_type))
-                || (is_interval(lhs_type) && is_datetime(rhs_type)) =>
-        {
-            mathematics_temporal_result_type(lhs_type, rhs_type)
-        }
-        // following same with `coerce_types`
-        Operator::BitwiseAnd
-        | Operator::BitwiseOr
-        | Operator::BitwiseXor
-        | Operator::BitwiseShiftRight
-        | Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type),
-        Operator::Plus
-        | Operator::Minus
-        | Operator::Modulo
-        | Operator::Divide
-        | Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type),
-        Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type),
-    };
-
-    result.ok_or(DataFusionError::Plan(format!(
-        "Unsupported argument types. Can not evaluate {lhs_type:?} {op} {rhs_type:?}"
-    )))
+    signature(lhs, op, rhs).map(|sig| sig.ret)
 }
 
-/// Coercion rules for all binary operators. Returns the 'coerce_types'
-/// is returns the type the arguments should be coerced to
-///
-/// Returns None if no suitable type can be found.
-pub fn coerce_types(
-    lhs_type: &DataType,
+/// Returns the coerced input types for a binary expression evaluating the `op` with the left and right hand types
+pub fn get_input_types(
+    lhs: &DataType,
     op: &Operator,
-    rhs_type: &DataType,
-) -> Result<DataType> {
-    // This result MUST be compatible with `binary_coerce`
-    let result = match op {
-        Operator::BitwiseAnd
-        | Operator::BitwiseOr
-        | Operator::BitwiseXor
-        | Operator::BitwiseShiftRight
-        | Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type),
-        Operator::And | Operator::Or => match (lhs_type, rhs_type) {
-            // logical binary boolean operators can only be evaluated in bools or nulls
-            (DataType::Boolean, DataType::Boolean)
-            | (DataType::Null, DataType::Null)
-            | (DataType::Boolean, DataType::Null)
-            | (DataType::Null, DataType::Boolean) => Some(DataType::Boolean),
-            _ => None,
-        },
-        // logical comparison operators have their own rules, and always return a boolean
-        Operator::Eq
-        | Operator::NotEq
-        | Operator::Lt
-        | Operator::Gt
-        | Operator::GtEq
-        | Operator::LtEq
-        | Operator::IsDistinctFrom
-        | Operator::IsNotDistinctFrom => comparison_coercion(lhs_type, rhs_type),
-        Operator::Plus | Operator::Minus
-            if is_interval(lhs_type) && is_interval(rhs_type) =>
-        {
-            temporal_coercion(lhs_type, rhs_type)
-        }
-        Operator::Minus if is_datetime(lhs_type) && is_datetime(rhs_type) => {
-            temporal_coercion(lhs_type, rhs_type)
-        }
-        // for math expressions, the final value of the coercion is also the return type
-        // because coercion favours higher information types
-        Operator::Plus
-        | Operator::Minus
-        | Operator::Modulo
-        | Operator::Divide
-        | Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type),
-        Operator::RegexMatch
-        | Operator::RegexIMatch
-        | Operator::RegexNotMatch
-        | Operator::RegexNotIMatch => regex_coercion(lhs_type, rhs_type),
-        // "||" operator has its own rules, and always return a string type
-        Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type),
-    };
-
-    // re-write the error message of failed coercions to include the operator's information
-    result.ok_or(DataFusionError::Plan(format!("{lhs_type:?} {op} {rhs_type:?} can't be evaluated because there isn't a common type to coerce the types to")))
+    rhs: &DataType,
+) -> Result<(DataType, DataType)> {
+    signature(lhs, op, rhs).map(|sig| (sig.lhs, sig.rhs))
 }
 
 /// Coercion rules for mathematics operators between decimal and non-decimal types.
-pub fn math_decimal_coercion(
+fn math_decimal_coercion(
     lhs_type: &DataType,
     rhs_type: &DataType,
-) -> (Option<DataType>, Option<DataType>) {
+) -> Option<(DataType, DataType)> {
     use arrow::datatypes::DataType::*;
 
-    if both_decimal(lhs_type, rhs_type) {
-        return (None, None);
-    }
-
     match (lhs_type, rhs_type) {
-        (Null, dec_type @ Decimal128(_, _)) => (Some(dec_type.clone()), None),
-        (dec_type @ Decimal128(_, _), Null) => (None, Some(dec_type.clone())),
         (Dictionary(key_type, value_type), _) => {
-            let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type);
-            let lhs_type = value_type
-                .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)));
-            (lhs_type, rhs_type)
+            let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?;
+            Some((Dictionary(key_type.clone(), Box::new(value_type)), rhs_type))
         }
         (_, Dictionary(key_type, value_type)) => {
-            let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type);
-            let rhs_type = value_type
-                .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)));
-            (lhs_type, rhs_type)
+            let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?;
+            Some((lhs_type, Dictionary(key_type.clone(), Box::new(value_type))))
         }
-        (Decimal128(_, _), Float32 | Float64) => (Some(Float64), Some(Float64)),
-        (Float32 | Float64, Decimal128(_, _)) => (Some(Float64), Some(Float64)),
-        (Decimal128(_, _), _) => {
-            let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type);
-            (None, converted_decimal_type)
+        (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
+            Some((dec_type.clone(), dec_type.clone()))
         }
-        (_, Decimal128(_, _)) => {
-            let converted_decimal_type = coerce_numeric_type_to_decimal(lhs_type);
-            (converted_decimal_type, None)
+        (Decimal128(_, _), Decimal128(_, _)) => {
+            Some((lhs_type.clone(), rhs_type.clone()))
         }
-        _ => (None, None),
+        // Unlike with comparison we don't coerce to a decimal in the case of floating point
+        // numbers, instead falling back to floating point arithmetic
+        (Decimal128(_, _), Int8 | Int16 | Int32 | Int64) => {
+            Some((lhs_type.clone(), coerce_numeric_type_to_decimal(rhs_type)?))
+        }
+        (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => {
+            Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone()))
+        }
+        _ => None,
     }
 }
 
 /// Returns the output type of applying bitwise operations such as
 /// `&`, `|`, or `xor`to arguments of `lhs_type` and `rhs_type`.
-pub(crate) fn bitwise_coercion(
-    left_type: &DataType,
-    right_type: &DataType,
-) -> Option<DataType> {
+fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
 
     if !both_numeric_or_null_and_numeric(left_type, right_type) {
@@ -289,9 +290,7 @@ pub(crate) fn bitwise_coercion(
     }
 }
 
-/// Returns the output type of applying comparison operations such as
-/// `eq`, `not eq`, `lt`, `lteq`, `gt`, and `gteq` to arguments
-/// of `lhs_type` and `rhs_type`.
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
 pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
     if lhs_type == rhs_type {
         // same type => equality is possible
@@ -303,11 +302,11 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
         .or_else(|| string_coercion(lhs_type, rhs_type))
         .or_else(|| null_coercion(lhs_type, rhs_type))
         .or_else(|| string_numeric_coercion(lhs_type, rhs_type))
+        .or_else(|| string_temporal_coercion(lhs_type, rhs_type))
 }
 
-/// Returns the output type of applying numeric operations such as `=`
-/// to arguments `lhs_type` and `rhs_type` if one is numeric and one
-/// is `Utf8`/`LargeUtf8`.
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
+/// where one is numeric and one is `Utf8`/`LargeUtf8`.
 fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
     match (lhs_type, rhs_type) {
@@ -319,8 +318,48 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
     }
 }
 
-/// Returns the output type of applying numeric operations such as `=`
-/// to arguments `lhs_type` and `rhs_type` if both are numeric
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
+/// where one is temporal and one is `Utf8`/`LargeUtf8`.
+///
+/// Note this cannot be performed in case of arithmetic as there is insufficient information
+/// to correctly determine the type of argument. Consider
+///
+/// ```sql
+/// timestamp > now() - '1 month'
+/// interval > now() - '1970-01-2021'
+/// ```
+///
+/// In the absence of a full type inference system, we can't determine the correct type
+/// to parse the string argument
+fn string_temporal_coercion(
+    lhs_type: &DataType,
+    rhs_type: &DataType,
+) -> Option<DataType> {
+    use arrow::datatypes::DataType::*;
+    match (lhs_type, rhs_type) {
+        (Utf8, Date32) | (Date32, Utf8) => Some(Date32),
+        (Utf8, Date64) | (Date64, Utf8) => Some(Date64),
+        (Utf8, Time32(unit)) | (Time32(unit), Utf8) => {
+            match is_time_with_valid_unit(Time32(unit.clone())) {
+                false => None,
+                true => Some(Time32(unit.clone())),
+            }
+        }
+        (Utf8, Time64(unit)) | (Time64(unit), Utf8) => {
+            match is_time_with_valid_unit(Time64(unit.clone())) {
+                false => None,
+                true => Some(Time64(unit.clone())),
+            }
+        }
+        (Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => {
+            Some(Timestamp(TimeUnit::Nanosecond, tz.clone()))
+        }
+        _ => None,
+    }
+}
+
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
+/// where both are numeric
 fn comparison_binary_numeric_coercion(
     lhs_type: &DataType,
     rhs_type: &DataType,
@@ -338,7 +377,7 @@ fn comparison_binary_numeric_coercion(
     // these are ordered from most informative to least informative so
     // that the coercion does not lose information via truncation
     match (lhs_type, rhs_type) {
-        // support decimal data type for comparison operation
+        // Prefer decimal data type over floating point for comparison operation
         (Decimal128(_, _), Decimal128(_, _)) => {
             get_wider_decimal_type(lhs_type, rhs_type)
         }
@@ -381,26 +420,14 @@ fn comparison_binary_numeric_coercion(
     }
 }
 
-/// Returns the output type of applying numeric operations such as `=`
-/// to a decimal type `decimal_type` and `other_type`
+/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of
+/// a comparison operation where one is a decimal
 fn get_comparison_common_decimal_type(
     decimal_type: &DataType,
     other_type: &DataType,
 ) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
-    let other_decimal_type = &match other_type {
-        // This conversion rule is from spark
-        // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
-        Int8 => Decimal128(3, 0),
-        Int16 => Decimal128(5, 0),
-        Int32 => Decimal128(10, 0),
-        Int64 => Decimal128(20, 0),
-        Float32 => Decimal128(14, 7),
-        Float64 => Decimal128(30, 15),
-        _ => {
-            return None;
-        }
-    };
+    let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?;
     match (decimal_type, &other_decimal_type) {
         (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2),
         _ => None,
@@ -430,6 +457,8 @@ fn get_wider_decimal_type(
 /// Now, we just support the signed integer type and floating-point type.
 fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
+    // This conversion rule is from spark
+    // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
     match numeric_type {
         Int8 => Some(Decimal128(3, 0)),
         Int16 => Some(Decimal128(5, 0)),
@@ -499,6 +528,7 @@ pub fn coercion_decimal_mathematics_type(
     left_decimal_type: &DataType,
     right_decimal_type: &DataType,
 ) -> Option<DataType> {
+    // TODO: Move this logic into kernel implementations
     use arrow::datatypes::DataType::*;
     match (left_decimal_type, right_decimal_type) {
         // The promotion rule from spark
@@ -518,7 +548,7 @@ pub fn coercion_decimal_mathematics_type(
     }
 }
 
-/// Returns the output type of applying mathematics operations on decimal types.
+/// Returns the output type of applying mathematics operations on two decimal types.
 /// The rule is from spark. Note that this is different to the coerced type applied
 /// to two sides of the arithmetic operation.
 pub fn decimal_op_mathematics_type(
@@ -605,29 +635,6 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) ->
     }
 }
 
-/// Determine if at least of one of lhs and rhs is decimal, and the other must be NULL or decimal
-fn both_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool {
-    use arrow::datatypes::DataType::*;
-    match (lhs_type, rhs_type) {
-        (_, Null) => is_decimal(lhs_type),
-        (Null, _) => is_decimal(rhs_type),
-        (Decimal128(_, _), Decimal128(_, _)) => true,
-        (Dictionary(_, value_type), _) => is_decimal(value_type) && is_decimal(rhs_type),
-        (_, Dictionary(_, value_type)) => is_decimal(lhs_type) && is_decimal(value_type),
-        _ => false,
-    }
-}
-
-/// Determine if at least of one of lhs and rhs is decimal
-pub fn any_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool {
-    use arrow::datatypes::DataType::*;
-    match (lhs_type, rhs_type) {
-        (Dictionary(_, value_type), _) => is_decimal(value_type) || is_decimal(rhs_type),
-        (_, Dictionary(_, value_type)) => is_decimal(lhs_type) || is_decimal(value_type),
-        (_, _) => is_decimal(lhs_type) || is_decimal(rhs_type),
-    }
-}
-
 /// Coercion rules for Dictionaries: the type that both lhs and rhs
 /// can be casted to for the purpose of a computation.
 ///
@@ -743,30 +750,10 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTyp
     use arrow::datatypes::IntervalUnit::*;
     use arrow::datatypes::TimeUnit::*;
 
-    if lhs_type == rhs_type {
-        return Some(lhs_type.clone());
-    }
     match (lhs_type, rhs_type) {
         // interval +/-
         (Interval(_), Interval(_)) => Some(Interval(MonthDayNano)),
         (Date64, Date32) | (Date32, Date64) => Some(Date64),
-        (Utf8, Date32) | (Date32, Utf8) => Some(Date32),
-        (Utf8, Date64) | (Date64, Utf8) => Some(Date64),
-        (Utf8, Time32(unit)) | (Time32(unit), Utf8) => {
-            match is_time_with_valid_unit(Time32(unit.clone())) {
-                false => None,
-                true => Some(Time32(unit.clone())),
-            }
-        }
-        (Utf8, Time64(unit)) | (Time64(unit), Utf8) => {
-            match is_time_with_valid_unit(Time64(unit.clone())) {
-                false => None,
-                true => Some(Time64(unit.clone())),
-            }
-        }
-        (Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => {
-            Some(Timestamp(Nanosecond, tz.clone()))
-        }
         (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => {
             Some(Timestamp(Nanosecond, None))
         }
@@ -832,7 +819,6 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
 mod tests {
     use arrow::datatypes::DataType;
 
-    use datafusion_common::assert_contains;
     use datafusion_common::DataFusionError;
     use datafusion_common::Result;
 
@@ -843,10 +829,13 @@ mod tests {
     #[test]
     fn test_coercion_error() -> Result<()> {
         let result_type =
-            coerce_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8);
+            get_input_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8);
 
         if let Err(DataFusionError::Plan(e)) = result_type {
-            assert_eq!(e, "Float32 + Utf8 can't be evaluated because there isn't a common type to coerce the types to");
+            assert_eq!(
+                e,
+                "Cannot coerce arithmetic expression Float32 + Utf8 to valid types"
+            );
             Ok(())
         } else {
             Err(DataFusionError::Internal(
@@ -891,12 +880,14 @@ mod tests {
         for (i, input_type) in input_types.iter().enumerate() {
             let expect_type = &result_types[i];
             for op in comparison_op_types {
-                let result_type = coerce_types(&input_decimal, &op, input_type)?;
-                assert_eq!(expect_type, &result_type);
+                let (lhs, rhs) = get_input_types(&input_decimal, &op, input_type)?;
+                assert_eq!(expect_type, &lhs);
+                assert_eq!(expect_type, &rhs);
             }
         }
         // negative test
-        let result_type = coerce_types(&input_decimal, &Operator::Eq, &DataType::Boolean);
+        let result_type =
+            get_input_types(&input_decimal, &Operator::Eq, &DataType::Boolean);
         assert!(result_type.is_err());
         Ok(())
     }
@@ -1017,24 +1008,27 @@ mod tests {
 
     macro_rules! test_coercion_binary_rule {
         ($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{
-            let result = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?;
-            assert_eq!(result, $C_TYPE);
+            let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?;
+            assert_eq!(lhs, $C_TYPE);
+            assert_eq!(rhs, $C_TYPE);
         }};
     }
 
     #[test]
     fn test_date_timestamp_arithmetic_error() -> Result<()> {
-        let common_type = coerce_types(
+        let (lhs, rhs) = get_input_types(
             &DataType::Timestamp(TimeUnit::Nanosecond, None),
             &Operator::Minus,
             &DataType::Timestamp(TimeUnit::Millisecond, None),
         )?;
-        assert_eq!(common_type.to_string(), "Timestamp(Millisecond, None)");
+        assert_eq!(lhs.to_string(), "Timestamp(Millisecond, None)");
+        assert_eq!(rhs.to_string(), "Timestamp(Millisecond, None)");
 
-        let err = coerce_types(&DataType::Date32, &Operator::Plus, &DataType::Date64)
-            .unwrap_err()
-            .to_string();
-        assert_contains!(&err, "Date32 + Date64 can't be evaluated because there isn't a common type to coerce the types to");
+        let (lhs, rhs) =
+            get_input_types(&DataType::Date32, &Operator::Plus, &DataType::Date64)
+                .unwrap();
+        assert_eq!(lhs.to_string(), "Date64");
+        assert_eq!(rhs.to_string(), "Date64");
 
         Ok(())
     }
@@ -1234,18 +1228,15 @@ mod tests {
         lhs_type: DataType,
         rhs_type: DataType,
         mathematics_op: Operator,
-        expected_lhs_type: Option<DataType>,
-        expected_rhs_type: Option<DataType>,
+        expected_lhs_type: DataType,
+        expected_rhs_type: DataType,
         expected_coerced_type: Option<DataType>,
         expected_output_type: DataType,
     ) {
         // The coerced types for lhs and rhs, if any of them is not decimal
-        let (l, r) = math_decimal_coercion(&lhs_type, &rhs_type);
-        assert_eq!(l, expected_lhs_type);
-        assert_eq!(r, expected_rhs_type);
-
-        let lhs_type = l.unwrap_or(lhs_type);
-        let rhs_type = r.unwrap_or(rhs_type);
+        let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap();
+        assert_eq!(lhs_type, expected_lhs_type);
+        assert_eq!(rhs_type, expected_rhs_type);
 
         // The coerced type of decimal math expression, applied during expression evaluation
         let coerced_type =
@@ -1264,8 +1255,8 @@ mod tests {
             DataType::Decimal128(10, 2),
             DataType::Decimal128(10, 2),
             Operator::Plus,
-            None,
-            None,
+            DataType::Decimal128(10, 2),
+            DataType::Decimal128(10, 2),
             Some(DataType::Decimal128(11, 2)),
             DataType::Decimal128(11, 2),
         );
@@ -1274,8 +1265,8 @@ mod tests {
             DataType::Int32,
             DataType::Decimal128(10, 2),
             Operator::Plus,
-            Some(DataType::Decimal128(10, 0)),
-            None,
+            DataType::Decimal128(10, 0),
+            DataType::Decimal128(10, 2),
             Some(DataType::Decimal128(13, 2)),
             DataType::Decimal128(13, 2),
         );
@@ -1284,8 +1275,8 @@ mod tests {
             DataType::Int32,
             DataType::Decimal128(10, 2),
             Operator::Minus,
-            Some(DataType::Decimal128(10, 0)),
-            None,
+            DataType::Decimal128(10, 0),
+            DataType::Decimal128(10, 2),
             Some(DataType::Decimal128(13, 2)),
             DataType::Decimal128(13, 2),
         );
@@ -1294,8 +1285,8 @@ mod tests {
             DataType::Int32,
             DataType::Decimal128(10, 2),
             Operator::Multiply,
-            Some(DataType::Decimal128(10, 0)),
-            None,
+            DataType::Decimal128(10, 0),
+            DataType::Decimal128(10, 2),
             None,
             DataType::Decimal128(21, 2),
         );
@@ -1304,8 +1295,8 @@ mod tests {
             DataType::Int32,
             DataType::Decimal128(10, 2),
             Operator::Divide,
-            Some(DataType::Decimal128(10, 0)),
-            None,
+            DataType::Decimal128(10, 0),
+            DataType::Decimal128(10, 2),
             Some(DataType::Decimal128(12, 2)),
             DataType::Decimal128(23, 11),
         );
@@ -1314,8 +1305,8 @@ mod tests {
             DataType::Int32,
             DataType::Decimal128(10, 2),
             Operator::Modulo,
-            Some(DataType::Decimal128(10, 0)),
-            None,
+            DataType::Decimal128(10, 0),
+            DataType::Decimal128(10, 2),
             Some(DataType::Decimal128(12, 2)),
             DataType::Decimal128(10, 2),
         );
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 412abbfae6..61153c0d36 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -32,7 +32,7 @@ use datafusion_expr::expr_rewriter::rewrite_preserving_name;
 use datafusion_expr::expr_schema::cast_subquery;
 use datafusion_expr::logical_plan::Subquery;
 use datafusion_expr::type_coercion::binary::{
-    any_decimal, coerce_types, comparison_coercion, like_coercion, math_decimal_coercion,
+    comparison_coercion, get_input_types, like_coercion,
 };
 use datafusion_expr::type_coercion::functions::data_types;
 use datafusion_expr::type_coercion::other::{
@@ -230,72 +230,18 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 let expr = Expr::ILike(Like::new(negated, expr, pattern, escape_char));
                 Ok(expr)
             }
-            Expr::BinaryExpr(BinaryExpr {
-                ref left,
-                op,
-                ref right,
-            }) => {
-                // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419
-                let left_type = left.get_type(&self.schema)?;
-                let right_type = right.get_type(&self.schema)?;
-                match (&left_type, &right_type) {
-                    // Handle some case about Interval.
-                    (
-                        DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _),
-                        &DataType::Interval(_),
-                    ) if matches!(op, Operator::Plus | Operator::Minus) => Ok(expr),
-                    (
-                        &DataType::Interval(_),
-                        DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _),
-                    ) if matches!(op, Operator::Plus) => Ok(expr),
-                    (DataType::Timestamp(_, _), DataType::Timestamp(_, _))
-                        if op.is_numerical_operators() =>
-                    {
-                        if matches!(op, Operator::Minus) {
-                            Ok(expr)
-                        } else {
-                            Err(DataFusionError::Internal(format!(
-                                "Unsupported operation {op:?} between {left_type:?} and {right_type:?}"
-                            )))
-                        }
-                    }
-                    // For numerical operations between decimals, we don't coerce the types.
-                    // But if only one of the operands is decimal, we cast the other operand to decimal
-                    // if the other operand is integer. If the other operand is float, we cast the
-                    // decimal operand to float.
-                    (lhs_type, rhs_type)
-                        if op.is_numerical_operators()
-                            && any_decimal(lhs_type, rhs_type) =>
-                    {
-                        let (coerced_lhs_type, coerced_rhs_type) =
-                            math_decimal_coercion(lhs_type, rhs_type);
-                        let new_left = if let Some(lhs_type) = coerced_lhs_type {
-                            left.clone().cast_to(&lhs_type, &self.schema)?
-                        } else {
-                            left.as_ref().clone()
-                        };
-                        let new_right = if let Some(rhs_type) = coerced_rhs_type {
-                            right.clone().cast_to(&rhs_type, &self.schema)?
-                        } else {
-                            right.as_ref().clone()
-                        };
-                        let expr = Expr::BinaryExpr(BinaryExpr::new(
-                            Box::new(new_left),
-                            op,
-                            Box::new(new_right),
-                        ));
-                        Ok(expr)
-                    }
-                    _ => {
-                        let common_type = coerce_types(&left_type, &op, &right_type)?;
-                        let expr = Expr::BinaryExpr(BinaryExpr::new(
-                            Box::new(left.clone().cast_to(&common_type, &self.schema)?),
-                            op,
-                            Box::new(right.clone().cast_to(&common_type, &self.schema)?),
-                        ));
-                        Ok(expr)
-                    }
-                }
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+                let (left_type, right_type) = get_input_types(
+                    &left.get_type(&self.schema)?,
+                    &op,
+                    &right.get_type(&self.schema)?,
+                )?;
+
+                Ok(Expr::BinaryExpr(BinaryExpr::new(
+                    Box::new(left.cast_to(&left_type, &self.schema)?),
+                    op,
+                    Box::new(right.cast_to(&right_type, &self.schema)?),
+                )))
             }
             Expr::Between(Between {
                 expr,
@@ -566,7 +512,7 @@ fn coerce_window_frame(
 // The above op will be rewrite to the binary op when creating the physical op.
 fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr> {
     let left_type = expr.get_type(schema)?;
-    coerce_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
+    get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
     expr.clone().cast_to(&DataType::Boolean, schema)
 }
 
@@ -1108,9 +1054,9 @@ mod test {
 
         let empty = empty_with_type(DataType::Int64);
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
-        let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "");
-        assert!(err.is_err());
-        assert!(err.unwrap_err().to_string().contains("Int64 IS DISTINCT FROM Boolean can't be evaluated because there isn't a common type to coerce the types to"));
+        let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "");
+        let err = ret.unwrap_err().to_string();
+        assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}");
 
         // is not true
         let expr = col("a").is_not_true();
@@ -1210,9 +1156,9 @@ mod test {
 
         let empty = empty_with_type(DataType::Utf8);
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
-        let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected);
-        assert!(err.is_err());
-        assert!(err.unwrap_err().to_string().contains("Utf8 IS DISTINCT FROM Boolean can't be evaluated because there isn't a common type to coerce the types to"));
+        let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected);
+        let err = ret.unwrap_err().to_string();
+        assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}");
 
         // is not unknown
         let expr = col("a").is_not_unknown();
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index e3bbefbcd3..c764692da3 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -1343,7 +1343,7 @@ mod tests {
         ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef,
     };
     use datafusion_common::{ColumnStatistics, Result, Statistics};
-    use datafusion_expr::type_coercion::binary::{coerce_types, math_decimal_coercion};
+    use datafusion_expr::type_coercion::binary::get_input_types;
 
     // Create a binary expression without coercion. Used here when we do not want to coerce the expressions
     // to valid types. Usage can result in an execution (after plan) error.
@@ -1447,10 +1447,10 @@ mod tests {
             ]);
             let a = $A_ARRAY::from($A_VEC);
             let b = $B_ARRAY::from($B_VEC);
-            let common_type = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?;
+            let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?;
 
-            let left = try_cast(col("a", &schema)?, &schema, common_type.clone())?;
-            let right = try_cast(col("b", &schema)?, &schema, common_type)?;
+            let left = try_cast(col("a", &schema)?, &schema, lhs)?;
+            let right = try_cast(col("b", &schema)?, &schema, rhs)?;
 
             // verify that we can construct the expression
             let expression = binary(left, $OP, right, &schema)?;
@@ -2964,10 +2964,10 @@ mod tests {
     ) -> Result<()> {
         let left_type = left.data_type();
         let right_type = right.data_type();
-        let common_type = coerce_types(left_type, &op, right_type)?;
+        let (lhs, rhs) = get_input_types(left_type, &op, right_type)?;
 
-        let left_expr = try_cast(col("a", schema)?, schema, common_type.clone())?;
-        let right_expr = try_cast(col("b", schema)?, schema, common_type)?;
+        let left_expr = try_cast(col("a", schema)?, schema, lhs)?;
+        let right_expr = try_cast(col("b", schema)?, schema, rhs)?;
         let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
         let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
         let batch = RecordBatch::try_new(schema.clone(), data)?;
@@ -2986,17 +2986,10 @@ mod tests {
         expected: &BooleanArray,
     ) -> Result<()> {
         let scalar = lit(scalar.clone());
-        let op_type = coerce_types(&scalar.data_type(schema)?, &op, arr.data_type())?;
-        let left_expr = if op_type.eq(&scalar.data_type(schema)?) {
-            scalar
-        } else {
-            try_cast(scalar, schema, op_type.clone())?
-        };
-        let right_expr = if op_type.eq(arr.data_type()) {
-            col("a", schema)?
-        } else {
-            try_cast(col("a", schema)?, schema, op_type)?
-        };
+        let (lhs, rhs) =
+            get_input_types(&scalar.data_type(schema)?, &op, arr.data_type())?;
+        let left_expr = try_cast(scalar, schema, lhs)?;
+        let right_expr = try_cast(col("a", schema)?, schema, rhs)?;
 
         let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
         let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
@@ -3015,17 +3008,10 @@ mod tests {
         expected: &BooleanArray,
     ) -> Result<()> {
         let scalar = lit(scalar.clone());
-        let op_type = coerce_types(arr.data_type(), &op, &scalar.data_type(schema)?)?;
-        let right_expr = if op_type.eq(&scalar.data_type(schema)?) {
-            scalar
-        } else {
-            try_cast(scalar, schema, op_type.clone())?
-        };
-        let left_expr = if op_type.eq(arr.data_type()) {
-            col("a", schema)?
-        } else {
-            try_cast(col("a", schema)?, schema, op_type)?
-        };
+        let (lhs, rhs) =
+            get_input_types(arr.data_type(), &op, &scalar.data_type(schema)?)?;
+        let left_expr = try_cast(col("a", schema)?, schema, lhs)?;
+        let right_expr = try_cast(scalar, schema, rhs)?;
 
         let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
         let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
@@ -4077,26 +4063,11 @@ mod tests {
         op: Operator,
         expected: ArrayRef,
     ) -> Result<()> {
-        let (lhs_op_type, rhs_op_type) =
-            math_decimal_coercion(left.data_type(), right.data_type());
+        let (lhs_type, rhs_type) =
+            get_input_types(left.data_type(), &op, right.data_type()).unwrap();
 
-        let (left_expr, lhs_type) = if let Some(lhs_op_type) = lhs_op_type {
-            (
-                try_cast(col("a", schema)?, schema, lhs_op_type.clone())?,
-                lhs_op_type,
-            )
-        } else {
-            (col("a", schema)?, left.data_type().clone())
-        };
-
-        let (right_expr, rhs_type) = if let Some(rhs_op_type) = rhs_op_type {
-            (
-                try_cast(col("b", schema)?, schema, rhs_op_type.clone())?,
-                rhs_op_type,
-            )
-        } else {
-            (col("b", schema)?, right.data_type().clone())
-        };
+        let left_expr = try_cast(col("a", schema)?, schema, lhs_type.clone())?;
+        let right_expr = try_cast(col("b", schema)?, schema, rhs_type.clone())?;
 
         let coerced_schema = Schema::new(vec![
             Field::new(