You are viewing a plain-text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/05/13 13:56:46 UTC
[arrow-datafusion] branch main updated: refine decimal multiply, avoid cast to wider type (#6331)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 063f99fd61 refine decimal multiply, avoid cast to wider type (#6331)
063f99fd61 is described below
commit 063f99fd61aa8f2854dc138d16864ae485a7116e
Author: Ken, Wang <mi...@gmail.com>
AuthorDate: Sat May 13 21:56:41 2023 +0800
refine decimal multiply, avoid cast to wider type (#6331)
* refine decimal multiply, avoid cast to wider type
* fix clippy
* fix fmt
---
datafusion/expr/src/type_coercion/binary.rs | 21 ++++++++++-----------
.../src/expressions/binary/kernels_arrow.rs | 22 +++-------------------
2 files changed, 13 insertions(+), 30 deletions(-)
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index fa912b777a..962434d652 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -514,7 +514,7 @@ pub fn coercion_decimal_mathematics_type(
left_decimal_type,
right_decimal_type,
),
- Operator::Multiply | Operator::Divide | Operator::Modulo => {
+ Operator::Divide | Operator::Modulo => {
get_wider_decimal_type(left_decimal_type, right_decimal_type)
}
_ => None,
@@ -946,7 +946,7 @@ mod tests {
&left_decimal_type,
&right_decimal_type,
);
- assert_eq!(DataType::Decimal128(20, 4), result.unwrap());
+ assert_eq!(None, result);
let result =
decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type);
assert_eq!(DataType::Decimal128(31, 7), result.unwrap());
@@ -1232,7 +1232,7 @@ mod tests {
mathematics_op: Operator,
expected_lhs_type: Option<DataType>,
expected_rhs_type: Option<DataType>,
- expected_coerced_type: DataType,
+ expected_coerced_type: Option<DataType>,
expected_output_type: DataType,
) {
// The coerced types for lhs and rhs, if any of them is not decimal
@@ -1245,8 +1245,7 @@ mod tests {
// The coerced type of decimal math expression, applied during expression evaluation
let coerced_type =
- coercion_decimal_mathematics_type(&mathematics_op, &lhs_type, &rhs_type)
- .unwrap();
+ coercion_decimal_mathematics_type(&mathematics_op, &lhs_type, &rhs_type);
assert_eq!(coerced_type, expected_coerced_type);
// The output type of decimal math expression
@@ -1263,7 +1262,7 @@ mod tests {
Operator::Plus,
None,
None,
- DataType::Decimal128(11, 2),
+ Some(DataType::Decimal128(11, 2)),
DataType::Decimal128(11, 2),
);
@@ -1273,7 +1272,7 @@ mod tests {
Operator::Plus,
Some(DataType::Decimal128(10, 0)),
None,
- DataType::Decimal128(13, 2),
+ Some(DataType::Decimal128(13, 2)),
DataType::Decimal128(13, 2),
);
@@ -1283,7 +1282,7 @@ mod tests {
Operator::Minus,
Some(DataType::Decimal128(10, 0)),
None,
- DataType::Decimal128(13, 2),
+ Some(DataType::Decimal128(13, 2)),
DataType::Decimal128(13, 2),
);
@@ -1293,7 +1292,7 @@ mod tests {
Operator::Multiply,
Some(DataType::Decimal128(10, 0)),
None,
- DataType::Decimal128(12, 2),
+ None,
DataType::Decimal128(21, 2),
);
@@ -1303,7 +1302,7 @@ mod tests {
Operator::Divide,
Some(DataType::Decimal128(10, 0)),
None,
- DataType::Decimal128(12, 2),
+ Some(DataType::Decimal128(12, 2)),
DataType::Decimal128(23, 11),
);
@@ -1313,7 +1312,7 @@ mod tests {
Operator::Modulo,
Some(DataType::Decimal128(10, 0)),
None,
- DataType::Decimal128(12, 2),
+ Some(DataType::Decimal128(12, 2)),
DataType::Decimal128(10, 2),
);
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 2852d617bf..90fca17157 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -35,9 +35,7 @@ use arrow_schema::DataType;
use datafusion_common::cast::{as_date32_array, as_date64_array, as_decimal128_array};
use datafusion_common::scalar::{date32_add, date64_add};
use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::type_coercion::binary::decimal_op_mathematics_type;
use datafusion_expr::ColumnarValue;
-use datafusion_expr::Operator;
use std::cmp::min;
use std::sync::Arc;
@@ -469,24 +467,8 @@ pub(crate) fn multiply_decimal_dyn_scalar(
result_type: &DataType,
) -> Result<ArrayRef> {
let (precision, scale) = get_precision_scale(result_type)?;
-
- let op_type = decimal_op_mathematics_type(
- &Operator::Multiply,
- left.data_type(),
- left.data_type(),
- )
- .unwrap();
- let (_, op_scale) = get_precision_scale(&op_type)?;
-
let array = multiply_scalar_dyn::<Decimal128Type>(left, right)?;
-
- if op_scale > scale {
- let div = 10_i128.pow((op_scale - scale) as u32);
- let array = divide_scalar_dyn::<Decimal128Type>(&array, div)?;
- decimal_array_with_precision_scale(array, precision, scale)
- } else {
- decimal_array_with_precision_scale(array, precision, scale)
- }
+ decimal_array_with_precision_scale(array, precision, scale)
}
pub(crate) fn divide_decimal_dyn_scalar(
@@ -703,6 +685,8 @@ pub(crate) fn modulus_decimal_dyn_scalar(
#[cfg(test)]
mod tests {
use super::*;
+ use datafusion_expr::type_coercion::binary::decimal_op_mathematics_type;
+ use datafusion_expr::Operator;
fn create_decimal_array(
array: &[Option<i128>],