You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/06/30 10:02:15 UTC
[arrow-datafusion] branch main updated: Consistently coerce dictionaries for arithmetic (#6785)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2f78536cd0 Consistently coerce dictionaries for arithmetic (#6785)
2f78536cd0 is described below
commit 2f78536cd0050dd321153c5cc9de0bcee7aa5b3f
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Fri Jun 30 11:02:09 2023 +0100
Consistently coerce dictionaries for arithmetic (#6785)
* Coerce dictionaries for arithmetic
* Clippy
---
datafusion/expr/src/type_coercion/binary.rs | 14 +-
datafusion/physical-expr/Cargo.toml | 2 +-
datafusion/physical-expr/src/expressions/binary.rs | 115 +++++----------
.../src/expressions/binary/kernels_arrow.rs | 159 ++-------------------
4 files changed, 53 insertions(+), 237 deletions(-)
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 64ebf8b559..c510822445 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -226,13 +226,13 @@ fn math_decimal_coercion(
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
- (Dictionary(key_type, value_type), _) => {
+ (Dictionary(_, value_type), _) => {
let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?;
- Some((Dictionary(key_type.clone(), Box::new(value_type)), rhs_type))
+ Some((value_type, rhs_type))
}
- (_, Dictionary(key_type, value_type)) => {
+ (_, Dictionary(_, value_type)) => {
let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?;
- Some((lhs_type, Dictionary(key_type.clone(), Box::new(value_type))))
+ Some((lhs_type, value_type))
}
(Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
Some((dec_type.clone(), dec_type.clone()))
@@ -490,10 +490,8 @@ fn mathematics_numerical_coercion(
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
mathematics_numerical_coercion(lhs_value_type, rhs_value_type)
}
- (Dictionary(key_type, value_type), _) => {
- let value_type = mathematics_numerical_coercion(value_type, rhs_type);
- value_type
- .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)))
+ (Dictionary(_, value_type), _) => {
+ mathematics_numerical_coercion(value_type, rhs_type)
}
(_, Dictionary(_, value_type)) => {
mathematics_numerical_coercion(lhs_type, value_type)
diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index e51af2d6d7..04ba2b9e38 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -37,7 +37,7 @@ crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
default = ["crypto_expressions", "regex_expressions", "unicode_expressions"]
# Enables support for non-scalar, binary operations on dictionaries
# Note: this results in significant additional codegen
-dictionary_expressions = ["arrow/dyn_cmp_dict", "arrow/dyn_arith_dict"]
+dictionary_expressions = ["arrow/dyn_cmp_dict"]
regex_expressions = ["regex"]
unicode_expressions = ["unicode-segmentation"]
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index a21ae21ca9..71e26ee45b 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -1351,15 +1351,20 @@ mod tests {
use datafusion_common::{ColumnStatistics, Result, Statistics};
use datafusion_expr::type_coercion::binary::get_input_types;
- // Create a binary expression without coercion. Used here when we do not want to coerce the expressions
- // to valid types. Usage can result in an execution (after plan) error.
- fn binary_simple(
- l: Arc<dyn PhysicalExpr>,
+ /// Performs a binary operation, applying any type coercion necessary
+ fn binary_op(
+ left: Arc<dyn PhysicalExpr>,
op: Operator,
- r: Arc<dyn PhysicalExpr>,
- input_schema: &Schema,
- ) -> Arc<dyn PhysicalExpr> {
- binary(l, op, r, input_schema).unwrap()
+ right: Arc<dyn PhysicalExpr>,
+ schema: &Schema,
+ ) -> Result<Arc<dyn PhysicalExpr>> {
+ let left_type = left.data_type(schema)?;
+ let right_type = right.data_type(schema)?;
+ let (lhs, rhs) = get_input_types(&left_type, &op, &right_type)?;
+
+ let left_expr = try_cast(left, schema, lhs)?;
+ let right_expr = try_cast(right, schema, rhs)?;
+ binary(left_expr, op, right_expr, schema)
}
#[test]
@@ -1372,12 +1377,12 @@ mod tests {
let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
// expression: "a < b"
- let lt = binary_simple(
+ let lt = binary(
col("a", &schema)?,
Operator::Lt,
col("b", &schema)?,
&schema,
- );
+ )?;
let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
@@ -1404,22 +1409,22 @@ mod tests {
let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
// expression: "a < b OR a == b"
- let expr = binary_simple(
- binary_simple(
+ let expr = binary(
+ binary(
col("a", &schema)?,
Operator::Lt,
col("b", &schema)?,
&schema,
- ),
+ )?,
Operator::Or,
- binary_simple(
+ binary(
col("a", &schema)?,
Operator::Eq,
col("b", &schema)?,
&schema,
- ),
+ )?,
&schema,
- );
+ )?;
let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
@@ -1492,7 +1497,7 @@ mod tests {
}
#[test]
- fn test_type_coersion() -> Result<()> {
+ fn test_type_coercion() -> Result<()> {
test_coercion!(
Int32Array,
DataType::Int32,
@@ -1814,8 +1819,7 @@ mod tests {
// is no way at the time of this writing to create a dictionary
// array using the `From` trait
#[test]
- #[cfg(feature = "dictionary_expressions")]
- fn test_dictionary_type_to_array_coersion() -> Result<()> {
+ fn test_dictionary_type_to_array_coercion() -> Result<()> {
// Test string a string dictionary
let dict_type =
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
@@ -1878,7 +1882,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn plus_op_dict() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -1912,7 +1915,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn plus_op_dict_decimal() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2096,7 +2098,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn minus_op_dict() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2130,7 +2131,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn minus_op_dict_decimal() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2306,7 +2306,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn multiply_op_dict() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2340,7 +2339,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn multiply_op_dict_decimal() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2514,7 +2512,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn divide_op_dict() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2554,7 +2551,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn divide_op_dict_decimal() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2740,7 +2736,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn modulus_op_dict() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2780,7 +2775,6 @@ mod tests {
}
#[test]
- #[cfg(feature = "dictionary_expressions")]
fn modulus_op_dict_decimal() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
@@ -2937,7 +2931,7 @@ mod tests {
expected: PrimitiveArray<T>,
) -> Result<()> {
let arithmetic_op =
- binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema);
+ binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?;
let batch = RecordBatch::try_new(schema, data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
@@ -2953,7 +2947,7 @@ mod tests {
expected: ArrayRef,
) -> Result<()> {
let lit = Arc::new(Literal::new(literal));
- let arithmetic_op = binary_simple(col("a", &schema)?, op, lit, &schema);
+ let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?;
let batch = RecordBatch::try_new(schema, data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
@@ -2968,16 +2962,10 @@ mod tests {
op: Operator,
expected: BooleanArray,
) -> Result<()> {
- let left_type = left.data_type();
- let right_type = right.data_type();
- let (lhs, rhs) = get_input_types(left_type, &op, right_type)?;
-
- let left_expr = try_cast(col("a", schema)?, schema, lhs)?;
- let right_expr = try_cast(col("b", schema)?, schema, rhs)?;
- let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
+ let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
- let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+ let result = op.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.as_ref(), &expected);
Ok(())
@@ -2992,14 +2980,9 @@ mod tests {
expected: &BooleanArray,
) -> Result<()> {
let scalar = lit(scalar.clone());
- let (lhs, rhs) =
- get_input_types(&scalar.data_type(schema)?, &op, arr.data_type())?;
- let left_expr = try_cast(scalar, schema, lhs)?;
- let right_expr = try_cast(col("a", schema)?, schema, rhs)?;
-
- let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
+ let op = binary_op(scalar, op, col("a", schema)?, schema)?;
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
- let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+ let result = op.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.as_ref(), expected);
Ok(())
@@ -3014,14 +2997,9 @@ mod tests {
expected: &BooleanArray,
) -> Result<()> {
let scalar = lit(scalar.clone());
- let (lhs, rhs) =
- get_input_types(arr.data_type(), &op, &scalar.data_type(schema)?)?;
- let left_expr = try_cast(col("a", schema)?, schema, lhs)?;
- let right_expr = try_cast(scalar, schema, rhs)?;
-
- let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
+ let op = binary_op(col("a", schema)?, op, scalar, schema)?;
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
- let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+ let result = op.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.as_ref(), expected);
Ok(())
@@ -3587,7 +3565,7 @@ mod tests {
let tree_depth: i32 = 100;
let expr = (0..tree_depth)
.map(|_| col("a", schema.as_ref()).unwrap())
- .reduce(|l, r| binary_simple(l, Operator::Plus, r, &schema))
+ .reduce(|l, r| binary(l, Operator::Plus, r, &schema).unwrap())
.unwrap();
let result = expr
@@ -4069,26 +4047,7 @@ mod tests {
op: Operator,
expected: ArrayRef,
) -> Result<()> {
- let (lhs_type, rhs_type) =
- get_input_types(left.data_type(), &op, right.data_type()).unwrap();
-
- let left_expr = try_cast(col("a", schema)?, schema, lhs_type.clone())?;
- let right_expr = try_cast(col("b", schema)?, schema, rhs_type.clone())?;
-
- let coerced_schema = Schema::new(vec![
- Field::new(
- schema.field(0).name(),
- lhs_type,
- schema.field(0).is_nullable(),
- ),
- Field::new(
- schema.field(1).name(),
- rhs_type,
- schema.field(1).is_nullable(),
- ),
- ]);
-
- let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema);
+ let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
@@ -4761,12 +4720,12 @@ mod tests {
// expression: "a >= 25"
let a = col("a", &schema).unwrap();
- let gt = binary_simple(
+ let gt = binary(
a.clone(),
Operator::GtEq,
lit(ScalarValue::from(25)),
&schema,
- );
+ )?;
let context = AnalysisContext::from_statistics(&schema, &statistics);
let predicate_boundaries = gt
@@ -4790,12 +4749,12 @@ mod tests {
// expression: "50 >= a"
let a = col("a", &schema).unwrap();
- let gt = binary_simple(
+ let gt = binary(
lit(ScalarValue::from(50)),
Operator::GtEq,
a.clone(),
&schema,
- );
+ )?;
let context = AnalysisContext::from_statistics(&schema, &statistics);
let predicate_boundaries = gt
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 1f59fcd141..e7d7f62c86 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -20,23 +20,17 @@
use arrow::compute::{
add_dyn, add_scalar_dyn, divide_dyn_checked, divide_scalar_dyn, modulus_dyn,
- modulus_scalar_dyn, multiply_dyn, multiply_fixed_point, multiply_scalar_dyn,
- subtract_dyn, subtract_scalar_dyn, try_unary,
+ modulus_scalar_dyn, multiply_fixed_point, multiply_scalar_dyn, subtract_dyn,
+ subtract_scalar_dyn, try_unary,
};
-use arrow::datatypes::{
- i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
- DECIMAL128_MAX_PRECISION,
-};
-use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array};
-use arrow_array::types::{ArrowDictionaryKeyType, DecimalType};
+use arrow::datatypes::{Date32Type, Date64Type, Decimal128Type};
+use arrow::{array::*, datatypes::ArrowNumericType};
use arrow_array::ArrowNativeTypeOp;
-use arrow_buffer::ArrowNativeType;
use arrow_schema::{DataType, IntervalUnit};
use chrono::{Days, Duration, Months, NaiveDate, NaiveDateTime};
use datafusion_common::cast::{as_date32_array, as_date64_array, as_decimal128_array};
use datafusion_common::scalar::{date32_op, date64_op};
use datafusion_common::{DataFusionError, Result, ScalarValue};
-use std::cmp::min;
use std::ops::Add;
use std::sync::Arc;
@@ -641,20 +635,6 @@ fn decimal_array_with_precision_scale(
Arc::new(array.clone().with_precision_and_scale(precision, scale)?)
as ArrayRef
}
- DataType::Dictionary(_, _) => {
- downcast_dictionary_array!(
- array => match array.values().data_type() {
- DataType::Decimal128(_, _) => {
- let decimal_dict_array = array.downcast_dict::<Decimal128Array>().unwrap();
- let decimal_array = decimal_dict_array.values().clone();
- let decimal_array = decimal_array.with_precision_and_scale(precision, scale)?;
- Arc::new(array.with_values(&decimal_array)) as ArrayRef
- }
- t => return Err(DataFusionError::Internal(format!("Unexpected dictionary value type {t}"))),
- },
- t => return Err(DataFusionError::Internal(format!("Unexpected datatype {t}"))),
- )
- }
_ => {
return Err(DataFusionError::Internal(
"Unexpected data type".to_string(),
@@ -698,75 +678,6 @@ pub(crate) fn subtract_dyn_decimal(
decimal_array_with_precision_scale(array, precision, scale)
}
-/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
-fn math_op_dict<K, T, F>(
- left: &DictionaryArray<K>,
- right: &DictionaryArray<K>,
- op: F,
-) -> Result<PrimitiveArray<T>>
-where
- K: ArrowDictionaryKeyType + ArrowNumericType,
- T: ArrowNumericType,
- F: Fn(T::Native, T::Native) -> T::Native,
-{
- if left.len() != right.len() {
- return Err(DataFusionError::Internal(format!(
- "Cannot perform operation on arrays of different length ({}, {})",
- left.len(),
- right.len()
- )));
- }
-
- // Safety justification: Since the inputs are valid Arrow arrays, all values are
- // valid indexes into the dictionary (which is verified during construction)
-
- let left_iter = unsafe {
- left.values()
- .as_primitive::<T>()
- .take_iter_unchecked(left.keys_iter())
- };
-
- let right_iter = unsafe {
- right
- .values()
- .as_primitive::<T>()
- .take_iter_unchecked(right.keys_iter())
- };
-
- let result = left_iter
- .zip(right_iter)
- .map(|(left_value, right_value)| {
- if let (Some(left), Some(right)) = (left_value, right_value) {
- Some(op(left, right))
- } else {
- None
- }
- })
- .collect();
-
- Ok(result)
-}
-
-/// Divide a decimal native value by given divisor and round the result.
-/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
-fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
-where
- I: DecimalType,
- I::Native: ArrowNativeTypeOp,
-{
- let d = input.div_wrapping(div);
- let r = input.mod_wrapping(div);
-
- let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
- let half_neg = half.neg_wrapping();
- // Round result
- match input >= I::Native::ZERO {
- true if r >= half => d.add_wrapping(I::Native::ONE),
- false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
- _ => d,
- }
-}
-
/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
/// <https://github.com/apache/arrow-rs/issues/4135>
fn multiply_fixed_point_dyn(
@@ -775,56 +686,9 @@ fn multiply_fixed_point_dyn(
required_scale: i8,
) -> Result<ArrayRef> {
match (left.data_type(), right.data_type()) {
- (
- DataType::Dictionary(_, lhs_value_type),
- DataType::Dictionary(_, rhs_value_type),
- ) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _))
- && matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) =>
- {
- downcast_dictionary_array!(
- left => match left.values().data_type() {
- DataType::Decimal128(_, _) => {
- let lhs_precision_scale = get_precision_scale(lhs_value_type.as_ref())?;
- let rhs_precision_scale = get_precision_scale(rhs_value_type.as_ref())?;
-
- let product_scale = lhs_precision_scale.1 + rhs_precision_scale.1;
- let precision = min(lhs_precision_scale.0 + rhs_precision_scale.0 + 1, DECIMAL128_MAX_PRECISION);
-
- if required_scale == product_scale {
- return Ok(multiply_dyn(left, right)?.as_primitive::<Decimal128Type>().clone()
- .with_precision_and_scale(precision, required_scale).map(|a| Arc::new(a) as ArrayRef)?);
- }
-
- if required_scale > product_scale {
- return Err(DataFusionError::Internal(format!(
- "Required scale {required_scale} is greater than product scale {product_scale}"
- )));
- }
-
- let divisor =
- i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
-
- let right = as_dictionary_array::<_>(right);
-
- let array = math_op_dict::<_, Decimal128Type, _>(left, right, |a, b| {
- let a = i256::from_i128(a);
- let b = i256::from_i128(b);
-
- let mut mul = a.wrapping_mul(b);
- mul = divide_and_round::<Decimal256Type>(mul, divisor);
- mul.as_i128()
- }).map(|a| a.with_precision_and_scale(precision, required_scale).unwrap())?;
-
- Ok(Arc::new(array))
- }
- t => unreachable!("Unsupported dictionary value type {}", t),
- },
- t => unreachable!("Unsupported data type {}", t),
- )
- }
(DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
- let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
- let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
+ let left = left.as_primitive::<Decimal128Type>();
+ let right = right.as_primitive::<Decimal128Type>();
Ok(multiply_fixed_point(left, right, required_scale)
.map(|a| Arc::new(a) as ArrayRef)?)
@@ -2525,10 +2389,11 @@ mod tests {
"1234567890.0000000000000000000000000000"
);
- // [123456789, 10]
+ // [123456789, 10, 10]
let a = Decimal128Array::from(vec![
123456789000000000000000000,
10000000000000000000,
+ 10000000000000000000,
])
.with_precision_and_scale(38, 18)
.unwrap();
@@ -2542,18 +2407,12 @@ mod tests {
.with_precision_and_scale(38, 18)
.unwrap();
- let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]);
- let array1 = DictionaryArray::new(keys, Arc::new(a));
- let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]);
- let array2 = DictionaryArray::new(keys, Arc::new(b));
-
- let result = multiply_fixed_point_dyn(&array1, &array2, 28).unwrap();
+ let result = multiply_fixed_point_dyn(&a, &b, 28).unwrap();
let expected = Arc::new(
Decimal128Array::from(vec![
Some(12345678900000000000000000000000000000),
Some(12345678900000000000000000000000000000),
Some(1200000000000000000000000000000),
- None,
])
.with_precision_and_scale(38, 28)
.unwrap(),