You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/05 17:48:04 UTC
[arrow-datafusion] branch master updated: Improve error handling for array downcasting (#4493)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 5bb9a1856 Improve error handling for array downcasting (#4493)
5bb9a1856 is described below
commit 5bb9a18563fad4dc9e18b7381c2542e32144f2f1
Author: Burak <bu...@gmail.com>
AuthorDate: Mon Dec 5 20:47:58 2022 +0300
Improve error handling for array downcasting (#4493)
* improve error handling and add some more types
* refactor booleanarray
---
datafusion/common/src/cast.rs | 39 ++++++++++--
datafusion/common/src/scalar.rs | 11 ++--
datafusion/core/src/datasource/listing/helpers.rs | 13 +---
.../core/src/physical_plan/file_format/json.rs | 19 ++----
datafusion/core/src/physical_plan/filter.rs | 12 +---
datafusion/physical-expr/src/crypto_expressions.rs | 19 ++----
.../physical-expr/src/datetime_expressions.rs | 24 +++-----
.../physical-expr/src/expressions/datetime.rs | 9 +--
datafusion/physical-expr/src/functions.rs | 11 ++--
datafusion/physical-expr/src/regex_expressions.rs | 51 ++++++---------
datafusion/physical-expr/src/string_expressions.rs | 59 +++++++-----------
.../physical-expr/src/unicode_expressions.rs | 72 +++++++---------------
datafusion/row/src/writer.rs | 6 +-
13 files changed, 136 insertions(+), 209 deletions(-)
diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index d22e2fb01..5e0bfbbd6 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -23,12 +23,13 @@
use crate::{downcast_value, DataFusionError};
use arrow::{
array::{
- Array, BinaryArray, BooleanArray, Date32Array, Decimal128Array, DictionaryArray,
- Float32Array, Float64Array, GenericBinaryArray, GenericListArray, Int32Array,
- Int64Array, LargeListArray, ListArray, MapArray, NullArray, OffsetSizeTrait,
- PrimitiveArray, StringArray, StructArray, TimestampMicrosecondArray,
- TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
- UInt32Array, UInt64Array, UnionArray,
+ Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array,
+ DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, Float32Array,
+ Float64Array, GenericBinaryArray, GenericListArray, GenericStringArray,
+ Int32Array, Int64Array, LargeListArray, ListArray, MapArray, NullArray,
+ OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
+ TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
+ TimestampSecondArray, UInt32Array, UInt64Array, UnionArray,
},
datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
};
@@ -177,3 +178,29 @@ pub fn as_timestamp_second_array(
pub fn as_binary_array(array: &dyn Array) -> Result<&BinaryArray, DataFusionError> {
Ok(downcast_value!(array, BinaryArray))
}
+
+// Downcast ArrayRef to FixedSizeListArray
+pub fn as_fixed_size_list_array(
+ array: &dyn Array,
+) -> Result<&FixedSizeListArray, DataFusionError> {
+ Ok(downcast_value!(array, FixedSizeListArray))
+}
+
+// Downcast ArrayRef to FixedSizeBinaryArray
+pub fn as_fixed_size_binary_array(
+ array: &dyn Array,
+) -> Result<&FixedSizeBinaryArray, DataFusionError> {
+ Ok(downcast_value!(array, FixedSizeBinaryArray))
+}
+
+// Downcast ArrayRef to Date64Array
+pub fn as_date64_array(array: &dyn Array) -> Result<&Date64Array, DataFusionError> {
+ Ok(downcast_value!(array, Date64Array))
+}
+
+// Downcast ArrayRef to GenericStringArray<T>
+pub fn as_generic_string_array<T: OffsetSizeTrait>(
+ array: &dyn Array,
+) -> Result<&GenericStringArray<T>, DataFusionError> {
+ Ok(downcast_value!(array, GenericStringArray, T))
+}
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index d7c5df065..46a1f16f7 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -26,7 +26,8 @@ use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
use crate::cast::{
- as_decimal128_array, as_dictionary_array, as_list_array, as_struct_array,
+ as_decimal128_array, as_dictionary_array, as_fixed_size_binary_array,
+ as_fixed_size_list_array, as_list_array, as_struct_array,
};
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};
@@ -2109,8 +2110,7 @@ impl ScalarValue {
Self::Struct(Some(field_values), Box::new(fields.clone()))
}
DataType::FixedSizeList(nested_type, _len) => {
- let list_array =
- array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
+ let list_array = as_fixed_size_list_array(array)?;
let value = match list_array.is_null(index) {
true => None,
false => {
@@ -2124,10 +2124,7 @@ impl ScalarValue {
ScalarValue::new_list(value, nested_type.data_type().clone())
}
DataType::FixedSizeBinary(_) => {
- let array = array
- .as_any()
- .downcast_ref::<FixedSizeBinaryArray>()
- .unwrap();
+ let array = as_fixed_size_binary_array(array)?;
let size = match array.data_type() {
DataType::FixedSizeBinary(size) => *size,
_ => unreachable!(),
diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index 3cfe9ec14..8b20fc5d6 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -21,10 +21,7 @@ use std::sync::Arc;
use arrow::array::new_empty_array;
use arrow::{
- array::{
- Array, ArrayBuilder, ArrayRef, Date64Array, Date64Builder, StringBuilder,
- UInt64Builder,
- },
+ array::{ArrayBuilder, ArrayRef, Date64Builder, StringBuilder, UInt64Builder},
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
};
@@ -40,7 +37,7 @@ use crate::{
use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use datafusion_common::{
- cast::{as_string_array, as_uint64_array},
+ cast::{as_date64_array, as_string_array, as_uint64_array},
Column, DataFusionError,
};
use datafusion_expr::{
@@ -341,11 +338,7 @@ fn batches_to_paths(batches: &[RecordBatch]) -> Result<Vec<PartitionedFile>> {
.flat_map(|batch| {
let key_array = as_string_array(batch.column(0)).unwrap();
let length_array = as_uint64_array(batch.column(1)).unwrap();
- let modified_array = batch
- .column(2)
- .as_any()
- .downcast_ref::<Date64Array>()
- .unwrap();
+ let modified_array = as_date64_array(batch.column(2)).unwrap();
(0..batch.num_rows()).map(move |row| {
Ok(PartitionedFile {
diff --git a/datafusion/core/src/physical_plan/file_format/json.rs b/datafusion/core/src/physical_plan/file_format/json.rs
index f301a7fdb..070e7aa54 100644
--- a/datafusion/core/src/physical_plan/file_format/json.rs
+++ b/datafusion/core/src/physical_plan/file_format/json.rs
@@ -261,6 +261,7 @@ mod tests {
use crate::prelude::NdJsonReadOptions;
use crate::prelude::*;
use crate::test::partitioned_file_groups;
+ use datafusion_common::cast::{as_int32_array, as_int64_array};
use rstest::*;
use tempfile::TempDir;
use url::Url;
@@ -362,11 +363,7 @@ mod tests {
let batch = it.next().await.unwrap()?;
assert_eq!(batch.num_rows(), 3);
- let values = batch
- .column(0)
- .as_any()
- .downcast_ref::<arrow::array::Int64Array>()
- .unwrap();
+ let values = as_int64_array(batch.column(0))?;
assert_eq!(values.value(0), 1);
assert_eq!(values.value(1), -10);
assert_eq!(values.value(2), 2);
@@ -416,11 +413,7 @@ mod tests {
let batch = it.next().await.unwrap()?;
assert_eq!(batch.num_rows(), 3);
- let values = batch
- .column(missing_field_idx)
- .as_any()
- .downcast_ref::<arrow::array::Int32Array>()
- .unwrap();
+ let values = as_int32_array(batch.column(missing_field_idx))?;
assert_eq!(values.len(), 3);
assert!(values.is_null(0));
assert!(values.is_null(1));
@@ -471,11 +464,7 @@ mod tests {
let batch = it.next().await.unwrap()?;
assert_eq!(batch.num_rows(), 4);
- let values = batch
- .column(0)
- .as_any()
- .downcast_ref::<arrow::array::Int64Array>()
- .unwrap();
+ let values = as_int64_array(batch.column(0))?;
assert_eq!(values.value(0), 1);
assert_eq!(values.value(1), -10);
assert_eq!(values.value(2), 2);
diff --git a/datafusion/core/src/physical_plan/filter.rs b/datafusion/core/src/physical_plan/filter.rs
index ed48d7b7d..c0ee6da48 100644
--- a/datafusion/core/src/physical_plan/filter.rs
+++ b/datafusion/core/src/physical_plan/filter.rs
@@ -31,11 +31,11 @@ use crate::physical_plan::{
Column, DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning,
PhysicalExpr,
};
-use arrow::array::BooleanArray;
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
+use datafusion_common::cast::as_boolean_array;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::BinaryExpr;
use datafusion_physical_expr::{split_conjunction, AnalysisContext};
@@ -217,15 +217,7 @@ fn batch_filter(
.map(|v| v.into_array(batch.num_rows()))
.map_err(DataFusionError::into)
.and_then(|array| {
- array
- .as_any()
- .downcast_ref::<BooleanArray>()
- .ok_or_else(|| {
- DataFusionError::Internal(
- "Filter predicate evaluated to non-boolean value".to_string(),
- )
- .into()
- })
+ Ok(as_boolean_array(&array)?)
// apply filter array to record batch
.and_then(|filter_array| filter_record_batch(batch, filter_array))
})
diff --git a/datafusion/physical-expr/src/crypto_expressions.rs b/datafusion/physical-expr/src/crypto_expressions.rs
index 07ebdbfd1..33806ac99 100644
--- a/datafusion/physical-expr/src/crypto_expressions.rs
+++ b/datafusion/physical-expr/src/crypto_expressions.rs
@@ -18,20 +18,19 @@
//! Crypto expressions
use arrow::{
- array::{
- Array, ArrayRef, BinaryArray, GenericStringArray, OffsetSizeTrait, StringArray,
- },
+ array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait, StringArray},
datatypes::DataType,
};
use blake2::{Blake2b512, Blake2s256, Digest};
use blake3::Hasher as Blake3;
-use datafusion_common::cast::{as_binary_array, as_generic_binary_array};
+use datafusion_common::cast::{
+ as_binary_array, as_generic_binary_array, as_generic_string_array,
+};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use md5::Md5;
use sha2::{Sha224, Sha256, Sha384, Sha512};
-use std::any::type_name;
use std::fmt::Write;
use std::sync::Arc;
use std::{fmt, str::FromStr};
@@ -167,15 +166,7 @@ impl DigestAlgorithm {
where
T: OffsetSizeTrait,
{
- let input_value = value
- .as_any()
- .downcast_ref::<GenericStringArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast value to {}",
- type_name::<GenericStringArray<T>>()
- ))
- })?;
+ let input_value = as_generic_string_array::<T>(value)?;
let array: ArrayRef = match self {
Self::Md5 => digest_to_array!(Md5, input_value),
Self::Sha224 => digest_to_array!(Sha224, input_value),
diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs
index 1f54cc672..ca973c16b 100644
--- a/datafusion/physical-expr/src/datetime_expressions.rs
+++ b/datafusion/physical-expr/src/datetime_expressions.rs
@@ -18,23 +18,22 @@
//! DateTime expressions
use arrow::{
- array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait, PrimitiveArray},
+ array::TimestampNanosecondArray, compute::kernels::temporal, datatypes::TimeUnit,
+ temporal_conversions::timestamp_ns_to_datetime,
+};
+use arrow::{
+ array::{Array, ArrayRef, OffsetSizeTrait, PrimitiveArray},
compute::kernels::cast_utils::string_to_timestamp_nanos,
datatypes::{
ArrowPrimitiveType, DataType, IntervalDayTimeType, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
},
};
-use arrow::{
- array::{Date64Array, TimestampNanosecondArray},
- compute::kernels::temporal,
- datatypes::TimeUnit,
- temporal_conversions::timestamp_ns_to_datetime,
-};
use chrono::prelude::*;
use chrono::Duration;
use datafusion_common::cast::{
- as_date32_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array,
+ as_date32_array, as_date64_array, as_generic_string_array,
+ as_timestamp_microsecond_array, as_timestamp_millisecond_array,
as_timestamp_nanosecond_array, as_timestamp_second_array,
};
use datafusion_common::{DataFusionError, Result};
@@ -69,12 +68,7 @@ where
)));
}
- let array = args[0]
- .as_any()
- .downcast_ref::<GenericStringArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal("failed to downcast to string".to_string())
- })?;
+ let array = as_generic_string_array::<T>(args[0])?;
// first map is the iterator, second is for the `Option<_>`
array
@@ -412,7 +406,7 @@ macro_rules! extract_date_part {
Err(e) => Err(e),
},
DataType::Date64 => {
- let array = $ARRAY.as_any().downcast_ref::<Date64Array>().unwrap();
+ let array = as_date64_array($ARRAY)?;
Ok($FN(array)?)
}
DataType::Timestamp(time_unit, None) => match time_unit {
diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs
index bb29709aa..a43eed090 100644
--- a/datafusion/physical-expr/src/expressions/datetime.rs
+++ b/datafusion/physical-expr/src/expressions/datetime.rs
@@ -17,7 +17,7 @@
use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;
-use arrow::array::{Array, ArrayRef, Date64Array};
+use arrow::array::{Array, ArrayRef};
use arrow::compute::unary;
use arrow::datatypes::{
DataType, Date32Type, Date64Type, Schema, TimeUnit, TimestampMicrosecondType,
@@ -25,8 +25,9 @@ use arrow::datatypes::{
};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::{
- as_date32_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array,
- as_timestamp_nanosecond_array, as_timestamp_second_array,
+ as_date32_array, as_date64_array, as_timestamp_microsecond_array,
+ as_timestamp_millisecond_array, as_timestamp_nanosecond_array,
+ as_timestamp_second_array,
};
use datafusion_common::scalar::{
date32_add, date64_add, microseconds_add, milliseconds_add, nanoseconds_add,
@@ -194,7 +195,7 @@ pub fn evaluate_array(
})) as ArrayRef
}
DataType::Date64 => {
- let array = array.as_any().downcast_ref::<Date64Array>().unwrap();
+ let array = as_date64_array(&array)?;
Arc::new(unary::<Date64Type, _, Date64Type>(array, |ms| {
date64_add(ms, scalar, sign).unwrap()
})) as ArrayRef
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 1ed83b89a..81a6b3af1 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -778,13 +778,13 @@ mod tests {
use crate::type_coercion::coerce;
use arrow::{
array::{
- Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array,
- Float64Array, Int32Array, StringArray, UInt32Array, UInt64Array,
+ Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
+ Int32Array, StringArray, UInt32Array, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
};
- use datafusion_common::cast::as_uint64_array;
+ use datafusion_common::cast::{as_fixed_size_list_array, as_uint64_array};
use datafusion_common::{Result, ScalarValue};
/// $FUNC function to test
@@ -2807,10 +2807,7 @@ mod tests {
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
// downcast works
- let result = result
- .as_any()
- .downcast_ref::<FixedSizeListArray>()
- .unwrap();
+ let result = as_fixed_size_list_array(&result)?;
// value is correct
assert_eq!(format!("{:?}", result.value(0)), expected);
diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs
index bdf61d5d0..68c2c20e7 100644
--- a/datafusion/physical-expr/src/regex_expressions.rs
+++ b/datafusion/physical-expr/src/regex_expressions.rs
@@ -26,12 +26,11 @@ use arrow::array::{
OffsetSizeTrait,
};
use arrow::compute;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{cast::as_generic_string_array, DataFusionError, Result};
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
use hashbrown::HashMap;
use lazy_static::lazy_static;
use regex::Regex;
-use std::any::type_name;
use std::sync::Arc;
use crate::functions::{make_scalar_function, make_scalar_function_with_hints, Hint};
@@ -42,7 +41,7 @@ use crate::functions::{make_scalar_function, make_scalar_function_with_hints, Hi
/// then calls the given early abort function.
macro_rules! fetch_string_arg {
($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{
- let array = downcast_string_array_arg!($ARG, $NAME, $T);
+ let array = as_generic_string_array::<T>($ARG)?;
if array.len() == 0 || array.is_null(0) {
return $EARLY_ABORT(array);
} else {
@@ -51,32 +50,18 @@ macro_rules! fetch_string_arg {
}};
}
-macro_rules! downcast_string_array_arg {
- ($ARG:expr, $NAME:expr, $T:ident) => {{
- $ARG.as_any()
- .downcast_ref::<GenericStringArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast {} to {}",
- $NAME,
- type_name::<GenericStringArray<T>>()
- ))
- })?
- }};
-}
-
/// extract a specific group from a string column, using a regular expression
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
- let values = downcast_string_array_arg!(args[0], "string", T);
- let regex = downcast_string_array_arg!(args[1], "pattern", T);
+ let values = as_generic_string_array::<T>(&args[0])?;
+ let regex = as_generic_string_array::<T>(&args[1])?;
compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError)
}
3 => {
- let values = downcast_string_array_arg!(args[0], "string", T);
- let regex = downcast_string_array_arg!(args[1], "pattern", T);
- let flags = Some(downcast_string_array_arg!(args[2], "flags", T));
+ let values = as_generic_string_array::<T>(&args[0])?;
+ let regex = as_generic_string_array::<T>(&args[1])?;
+ let flags = Some(as_generic_string_array::<T>(&args[2])?);
match flags {
Some(f) if f.iter().any(|s| s == Some("g")) => {
@@ -115,9 +100,9 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
match args.len() {
3 => {
- let string_array = downcast_string_array_arg!(args[0], "string", T);
- let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
- let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let pattern_array = as_generic_string_array::<T>(&args[1])?;
+ let replacement_array = as_generic_string_array::<T>(&args[2])?;
let result = string_array
.iter()
@@ -150,10 +135,10 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
Ok(Arc::new(result) as ArrayRef)
}
4 => {
- let string_array = downcast_string_array_arg!(args[0], "string", T);
- let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
- let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
- let flags_array = downcast_string_array_arg!(args[3], "flags", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let pattern_array = as_generic_string_array::<T>(&args[1])?;
+ let replacement_array = as_generic_string_array::<T>(&args[2])?;
+ let flags_array = as_generic_string_array::<T>(&args[3])?;
let result = string_array
.iter()
@@ -227,13 +212,13 @@ fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
- let string_array = downcast_string_array_arg!(args[0], "string", T);
- let pattern = fetch_string_arg!(args[1], "pattern", T, _regexp_replace_early_abort);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let pattern = fetch_string_arg!(&args[1], "pattern", T, _regexp_replace_early_abort);
let replacement =
- fetch_string_arg!(args[2], "replacement", T, _regexp_replace_early_abort);
+ fetch_string_arg!(&args[2], "replacement", T, _regexp_replace_early_abort);
let flags = match args.len() {
3 => None,
- 4 => Some(fetch_string_arg!(args[3], "flags", T, _regexp_replace_early_abort)),
+ 4 => Some(fetch_string_arg!(&args[3], "flags", T, _regexp_replace_early_abort)),
other => {
return Err(DataFusionError::Internal(format!(
"regexp_replace was called with {} arguments. It requires at least 3 and at most 4.",
diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs
index 2327521d7..7048354f1 100644
--- a/datafusion/physical-expr/src/string_expressions.rs
+++ b/datafusion/physical-expr/src/string_expressions.rs
@@ -29,30 +29,17 @@ use arrow::{
datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType},
};
use datafusion_common::{
- cast::{as_int64_array, as_primitive_array, as_string_array},
+ cast::{
+ as_generic_string_array, as_int64_array, as_primitive_array, as_string_array,
+ },
ScalarValue,
};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
-use std::any::type_name;
use std::iter;
use std::sync::Arc;
use uuid::Uuid;
-macro_rules! downcast_string_arg {
- ($ARG:expr, $NAME:expr, $T:ident) => {{
- $ARG.as_any()
- .downcast_ref::<GenericStringArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast {} to {}",
- $NAME,
- type_name::<GenericStringArray<T>>()
- ))
- })?
- }};
-}
-
/// applies a unary expression to `args[0]` that is expected to be downcastable to
/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset)
/// # Errors
@@ -78,7 +65,7 @@ where
)));
}
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(args[0])?;
// first map is the iterator, second is for the `Option<_>`
Ok(string_array.iter().map(|string| string.map(&op)).collect())
@@ -136,7 +123,7 @@ where
/// Returns the numeric code of the first character of the argument.
/// ascii('x') = 120
pub fn ascii<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let result = string_array
.iter()
@@ -156,7 +143,7 @@ pub fn ascii<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
1 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let result = string_array
.iter()
@@ -170,8 +157,8 @@ pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}
2 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
- let characters_array = downcast_string_arg!(args[1], "characters", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let characters_array = as_generic_string_array::<T>(&args[1])?;
let result = string_array
.iter()
@@ -335,7 +322,7 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.
/// initcap('hi THOMAS') = 'Hi Thomas'
pub fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
// first map is the iterator, second is for the `Option<_>`
let result = string_array
@@ -373,7 +360,7 @@ pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
1 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let result = string_array
.iter()
@@ -383,8 +370,8 @@ pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}
2 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
- let characters_array = downcast_string_arg!(args[1], "characters", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let characters_array = as_generic_string_array::<T>(&args[1])?;
let result = string_array
.iter()
@@ -410,7 +397,7 @@ pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Repeats string the specified number of times.
/// repeat('Pg', 4) = 'PgPgPgPg'
pub fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let number_array = as_int64_array(&args[1])?;
let result = string_array
@@ -428,9 +415,9 @@ pub fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Replaces all occurrences in string of substring from with substring to.
/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef'
pub fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
- let from_array = downcast_string_arg!(args[1], "from", T);
- let to_array = downcast_string_arg!(args[2], "to", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let from_array = as_generic_string_array::<T>(&args[1])?;
+ let to_array = as_generic_string_array::<T>(&args[2])?;
let result = string_array
.iter()
@@ -450,7 +437,7 @@ pub fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
1 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let result = string_array
.iter()
@@ -460,8 +447,8 @@ pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}
2 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
- let characters_array = downcast_string_arg!(args[1], "characters", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let characters_array = as_generic_string_array::<T>(&args[1])?;
let result = string_array
.iter()
@@ -487,8 +474,8 @@ pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Splits string at occurrences of delimiter and returns the n'th field (counting from one).
/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def'
pub fn split_part<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
- let delimiter_array = downcast_string_arg!(args[1], "delimiter", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let delimiter_array = as_generic_string_array::<T>(&args[1])?;
let n_array = as_int64_array(&args[2])?;
let result = string_array
.iter()
@@ -518,8 +505,8 @@ pub fn split_part<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Returns true if string starts with prefix.
/// starts_with('alphabet', 'alph') = 't'
pub fn starts_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
- let prefix_array = downcast_string_arg!(args[1], "prefix", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let prefix_array = as_generic_string_array::<T>(&args[1])?;
let result = string_array
.iter()
diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs
index 0e35521c4..37180bff5 100644
--- a/datafusion/physical-expr/src/unicode_expressions.rs
+++ b/datafusion/physical-expr/src/unicode_expressions.rs
@@ -25,27 +25,13 @@ use arrow::{
array::{ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, PrimitiveArray},
datatypes::{ArrowNativeType, ArrowPrimitiveType},
};
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{cast::as_generic_string_array, DataFusionError, Result};
use hashbrown::HashMap;
use std::cmp::Ordering;
use std::sync::Arc;
use std::{any::type_name, cmp::max};
use unicode_segmentation::UnicodeSegmentation;
-macro_rules! downcast_string_arg {
- ($ARG:expr, $NAME:expr, $T:ident) => {{
- $ARG.as_any()
- .downcast_ref::<GenericStringArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast {} to {}",
- $NAME,
- type_name::<GenericStringArray<T>>()
- ))
- })?
- }};
-}
-
macro_rules! downcast_arg {
($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
$ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
@@ -65,12 +51,8 @@ pub fn character_length<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<Arra
where
T::Native: OffsetSizeTrait,
{
- let string_array: &GenericStringArray<T::Native> = args[0]
- .as_any()
- .downcast_ref::<GenericStringArray<T::Native>>()
- .ok_or_else(|| {
- DataFusionError::Internal("could not cast string to StringArray".to_string())
- })?;
+ let string_array: &GenericStringArray<T::Native> =
+ as_generic_string_array::<T::Native>(&args[0])?;
let result = string_array
.iter()
@@ -89,7 +71,7 @@ where
/// left('abcde', 2) = 'ab'
/// The implementation uses UTF-8 code points as characters
pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let n_array = downcast_arg!(args[1], "n", Int64Array);
let result = string_array
.iter()
@@ -121,7 +103,7 @@ pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let length_array = downcast_arg!(args[1], "length", Int64Array);
let result = string_array
@@ -157,9 +139,9 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}
3 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let length_array = downcast_arg!(args[1], "length", Int64Array);
- let fill_array = downcast_string_arg!(args[2], "fill", T);
+ let fill_array = as_generic_string_array::<T>(&args[2])?;
let result = string_array
.iter()
@@ -219,7 +201,7 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// reverse('abcde') = 'edcba'
/// The implementation uses UTF-8 code points as characters
pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let result = string_array
.iter()
@@ -233,7 +215,7 @@ pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// right('abcde', 2) = 'de'
/// The implementation uses UTF-8 code points as characters
pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let n_array = downcast_arg!(args[1], "n", Int64Array);
let result = string_array
@@ -267,7 +249,7 @@ pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let length_array = downcast_arg!(args[1], "length", Int64Array);
let result = string_array
@@ -302,9 +284,9 @@ pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}
3 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let length_array = downcast_arg!(args[1], "length", Int64Array);
- let fill_array = downcast_string_arg!(args[2], "fill", T);
+ let fill_array = as_generic_string_array::<T>(&args[2])?;
let result = string_array
.iter()
@@ -359,21 +341,11 @@ pub fn strpos<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
{
- let string_array: &GenericStringArray<T::Native> = args[0]
- .as_any()
- .downcast_ref::<GenericStringArray<T::Native>>()
- .ok_or_else(|| {
- DataFusionError::Internal("could not cast string to StringArray".to_string())
- })?;
-
- let substring_array: &GenericStringArray<T::Native> = args[1]
- .as_any()
- .downcast_ref::<GenericStringArray<T::Native>>()
- .ok_or_else(|| {
- DataFusionError::Internal(
- "could not cast substring to StringArray".to_string(),
- )
- })?;
+ let string_array: &GenericStringArray<T::Native> =
+ as_generic_string_array::<T::Native>(&args[0])?;
+
+ let substring_array: &GenericStringArray<T::Native> =
+ as_generic_string_array::<T::Native>(&args[1])?;
let result = string_array
.iter()
@@ -403,7 +375,7 @@ where
pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let start_array = downcast_arg!(args[1], "start", Int64Array);
let result = string_array
@@ -424,7 +396,7 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}
3 => {
- let string_array = downcast_string_arg!(args[0], "string", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
let start_array = downcast_arg!(args[1], "start", Int64Array);
let count_array = downcast_arg!(args[2], "count", Int64Array);
@@ -462,9 +434,9 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
/// translate('12345', '143', 'ax') = 'a2x5'
pub fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = downcast_string_arg!(args[0], "string", T);
- let from_array = downcast_string_arg!(args[1], "from", T);
- let to_array = downcast_string_arg!(args[2], "to", T);
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let from_array = as_generic_string_array::<T>(&args[1])?;
+ let to_array = as_generic_string_array::<T>(&args[2])?;
let result = string_array
.iter()
diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs
index 02325c1d6..451d84f17 100644
--- a/datafusion/row/src/writer.rs
+++ b/datafusion/row/src/writer.rs
@@ -22,7 +22,9 @@ use arrow::array::*;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util::{round_upto_power_of_2, set_bit_raw, unset_bit_raw};
-use datafusion_common::cast::{as_binary_array, as_date32_array, as_string_array};
+use datafusion_common::cast::{
+ as_binary_array, as_date32_array, as_date64_array, as_string_array,
+};
use datafusion_common::Result;
use std::cmp::max;
use std::sync::Arc;
@@ -339,7 +341,7 @@ pub(crate) fn write_field_date64(
col_idx: usize,
row_idx: usize,
) {
- let from = from.as_any().downcast_ref::<Date64Array>().unwrap();
+ let from = as_date64_array(from).unwrap();
to.set_date64(col_idx, from.value(row_idx));
}