You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/27 11:48:15 UTC
[arrow-datafusion] branch master updated: improve error handling and add some more types (#4352)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new da54fa584 improve error handling and add some more types (#4352)
da54fa584 is described below
commit da54fa584c7c4e8069ca6aed51e7c1f0dffc807c
Author: Burak <bu...@gmail.com>
AuthorDate: Sun Nov 27 14:48:10 2022 +0300
improve error handling and add some more types (#4352)
---
datafusion/common/src/cast.rs | 47 +++++++++++++++--
datafusion/common/src/scalar.rs | 8 +--
.../core/src/physical_plan/joins/hash_join.rs | 17 +++---
datafusion/core/src/physical_plan/sorts/sort.rs | 14 ++---
datafusion/core/src/physical_plan/windows/mod.rs | 7 +--
datafusion/core/tests/custom_sources.rs | 15 ++----
datafusion/core/tests/provider_filter_pushdown.rs | 7 +--
datafusion/core/tests/sql/parquet.rs | 20 ++------
datafusion/core/tests/user_defined_aggregates.rs | 7 +--
datafusion/physical-expr/src/aggregate/median.rs | 19 ++-----
datafusion/physical-expr/src/crypto_expressions.rs | 14 ++---
.../physical-expr/src/expressions/in_list.rs | 12 +++--
datafusion/physical-expr/src/hash_utils.rs | 12 +++--
datafusion/physical-expr/src/string_expressions.rs | 60 +++++-----------------
14 files changed, 118 insertions(+), 141 deletions(-)
diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index 3d5036d7e..bca3dbde3 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -21,10 +21,14 @@
//! kernels in arrow-rs such as `as_boolean_array` do.
use crate::{downcast_value, DataFusionError};
-use arrow::array::{
- Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
- Int32Array, Int64Array, ListArray, StringArray, StructArray, UInt32Array,
- UInt64Array,
+use arrow::{
+ array::{
+ Array, BooleanArray, Date32Array, Decimal128Array, DictionaryArray, Float32Array,
+ Float64Array, GenericBinaryArray, GenericListArray, Int32Array, Int64Array,
+ LargeListArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
+ StructArray, UInt32Array, UInt64Array,
+ },
+ datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
};
// Downcast ArrayRef to Date32Array
@@ -88,3 +92,38 @@ pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, DataFusionEr
pub fn as_list_array(array: &dyn Array) -> Result<&ListArray, DataFusionError> {
Ok(downcast_value!(array, ListArray))
}
+
+// Downcast ArrayRef to DictionaryArray
+pub fn as_dictionary_array<T: ArrowDictionaryKeyType>(
+ array: &dyn Array,
+) -> Result<&DictionaryArray<T>, DataFusionError> {
+ Ok(downcast_value!(array, DictionaryArray, T))
+}
+
+// Downcast ArrayRef to GenericBinaryArray
+pub fn as_generic_binary_array<T: OffsetSizeTrait>(
+ array: &dyn Array,
+) -> Result<&GenericBinaryArray<T>, DataFusionError> {
+ Ok(downcast_value!(array, GenericBinaryArray, T))
+}
+
+// Downcast ArrayRef to GenericListArray
+pub fn as_generic_list_array<T: OffsetSizeTrait>(
+ array: &dyn Array,
+) -> Result<&GenericListArray<T>, DataFusionError> {
+ Ok(downcast_value!(array, GenericListArray, T))
+}
+
+// Downcast ArrayRef to LargeListArray
+pub fn as_large_list_array(
+ array: &dyn Array,
+) -> Result<&LargeListArray, DataFusionError> {
+ Ok(downcast_value!(array, LargeListArray))
+}
+
+// Downcast ArrayRef to PrimitiveArray
+pub fn as_primitive_array<T: ArrowPrimitiveType>(
+ array: &dyn Array,
+) -> Result<&PrimitiveArray<T>, DataFusionError> {
+ Ok(downcast_value!(array, PrimitiveArray, T))
+}
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 9a1119469..7f2ea5533 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -25,7 +25,9 @@ use std::ops::{Add, Sub};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
-use crate::cast::{as_decimal128_array, as_list_array, as_struct_array};
+use crate::cast::{
+ as_decimal128_array, as_dictionary_array, as_list_array, as_struct_array,
+};
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};
use arrow::{
@@ -721,7 +723,7 @@ fn get_dict_value<K: ArrowDictionaryKeyType>(
array: &ArrayRef,
index: usize,
) -> (&ArrayRef, Option<usize>) {
- let dict_array = as_dictionary_array::<K>(array);
+ let dict_array = as_dictionary_array::<K>(array).unwrap();
(dict_array.values(), dict_array.key(index))
}
@@ -3212,7 +3214,7 @@ mod tests {
];
let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap();
- let array = as_dictionary_array::<Int32Type>(&array);
+ let array = as_dictionary_array::<Int32Type>(&array).unwrap();
let values_array = as_string_array(array.values()).unwrap();
let values = array
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 89583c03c..3017df623 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -22,12 +22,11 @@ use ahash::RandomState;
use arrow::{
array::{
- as_dictionary_array, ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array,
- Decimal128Array, DictionaryArray, LargeStringArray, PrimitiveArray,
- Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
- Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
- TimestampSecondArray, UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder,
- UInt64Builder,
+ ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array,
+ DictionaryArray, LargeStringArray, PrimitiveArray, Time32MillisecondArray,
+ Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
+ TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray,
+ UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder,
},
compute,
datatypes::{
@@ -54,7 +53,7 @@ use arrow::array::{
UInt8Array,
};
-use datafusion_common::cast::{as_boolean_array, as_string_array};
+use datafusion_common::cast::{as_boolean_array, as_dictionary_array, as_string_array};
use hashbrown::raw::RawTable;
@@ -1127,9 +1126,9 @@ macro_rules! equal_rows_elem {
macro_rules! equal_rows_elem_with_string_dict {
($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
let left_array: &DictionaryArray<$key_array_type> =
- as_dictionary_array::<$key_array_type>($l);
+ as_dictionary_array::<$key_array_type>($l).unwrap();
let right_array: &DictionaryArray<$key_array_type> =
- as_dictionary_array::<$key_array_type>($r);
+ as_dictionary_array::<$key_array_type>($r).unwrap();
let (left_values, left_values_index) = {
let keys_col = left_array.keys();
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 0dfac7876..6e87ba76c 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -951,7 +951,7 @@ mod tests {
use arrow::array::*;
use arrow::compute::SortOptions;
use arrow::datatypes::*;
- use datafusion_common::cast::as_string_array;
+ use datafusion_common::cast::{as_primitive_array, as_string_array};
use futures::FutureExt;
use std::collections::{BTreeMap, HashMap};
@@ -995,11 +995,11 @@ mod tests {
assert_eq!(c1.value(0), "a");
assert_eq!(c1.value(c1.len() - 1), "e");
- let c2 = as_primitive_array::<UInt32Type>(&columns[1]);
+ let c2 = as_primitive_array::<UInt32Type>(&columns[1])?;
assert_eq!(c2.value(0), 1);
assert_eq!(c2.value(c2.len() - 1), 5,);
- let c7 = as_primitive_array::<UInt8Type>(&columns[6]);
+ let c7 = as_primitive_array::<UInt8Type>(&columns[6])?;
assert_eq!(c7.value(0), 15);
assert_eq!(c7.value(c7.len() - 1), 254,);
@@ -1067,11 +1067,11 @@ mod tests {
assert_eq!(c1.value(0), "a");
assert_eq!(c1.value(c1.len() - 1), "e");
- let c2 = as_primitive_array::<UInt32Type>(&columns[1]);
+ let c2 = as_primitive_array::<UInt32Type>(&columns[1])?;
assert_eq!(c2.value(0), 1);
assert_eq!(c2.value(c2.len() - 1), 5,);
- let c7 = as_primitive_array::<UInt8Type>(&columns[6]);
+ let c7 = as_primitive_array::<UInt8Type>(&columns[6])?;
assert_eq!(c7.value(0), 15);
assert_eq!(c7.value(c7.len() - 1), 254,);
@@ -1271,8 +1271,8 @@ mod tests {
assert_eq!(DataType::Float32, *columns[0].data_type());
assert_eq!(DataType::Float64, *columns[1].data_type());
- let a = as_primitive_array::<Float32Type>(&columns[0]);
- let b = as_primitive_array::<Float64Type>(&columns[1]);
+ let a = as_primitive_array::<Float32Type>(&columns[0])?;
+ let b = as_primitive_array::<Float64Type>(&columns[1])?;
// convert result to strings to allow comparing to expected result containing NaN
let result: Vec<(Option<String>, Option<String>)> = (0..result[0].num_rows())
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index a488c6ffa..60ddb6f5d 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -171,6 +171,7 @@ mod tests {
use arrow::array::*;
use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::record_batch::RecordBatch;
+ use datafusion_common::cast::as_primitive_array;
use futures::FutureExt;
fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>, SchemaRef)> {
@@ -228,15 +229,15 @@ mod tests {
// c3 is small int
- let count: &Int64Array = as_primitive_array(&columns[0]);
+ let count: &Int64Array = as_primitive_array(&columns[0])?;
assert_eq!(count.value(0), 100);
assert_eq!(count.value(99), 100);
- let max: &Int8Array = as_primitive_array(&columns[1]);
+ let max: &Int8Array = as_primitive_array(&columns[1])?;
assert_eq!(max.value(0), 125);
assert_eq!(max.value(99), 125);
- let min: &Int8Array = as_primitive_array(&columns[2]);
+ let min: &Int8Array = as_primitive_array(&columns[2])?;
assert_eq!(min.value(0), -117);
assert_eq!(min.value(99), -117);
diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs
index 317f2983b..fde296d85 100644
--- a/datafusion/core/tests/custom_sources.rs
+++ b/datafusion/core/tests/custom_sources.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::array::{Int32Array, Int64Array, PrimitiveArray};
+use arrow::array::{Int32Array, Int64Array};
use arrow::compute::kernels::aggregate;
use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
@@ -38,6 +38,7 @@ use datafusion::{
};
use datafusion::{error::Result, physical_plan::DisplayFormatType};
+use datafusion_common::cast::as_primitive_array;
use futures::stream::Stream;
use std::any::Any;
use std::pin::Pin;
@@ -162,18 +163,10 @@ impl ExecutionPlan for CustomExecutionPlan {
.map(|i| ColumnStatistics {
null_count: Some(batch.column(*i).null_count()),
min_value: Some(ScalarValue::Int32(aggregate::min(
- batch
- .column(*i)
- .as_any()
- .downcast_ref::<PrimitiveArray<Int32Type>>()
- .unwrap(),
+ as_primitive_array::<Int32Type>(batch.column(*i)).unwrap(),
))),
max_value: Some(ScalarValue::Int32(aggregate::max(
- batch
- .column(*i)
- .as_any()
- .downcast_ref::<PrimitiveArray<Int32Type>>()
- .unwrap(),
+ as_primitive_array::<Int32Type>(batch.column(*i)).unwrap(),
))),
..Default::default()
})
diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs
index 84c8c50cd..2cc57f8b0 100644
--- a/datafusion/core/tests/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/provider_filter_pushdown.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::array::{as_primitive_array, Int32Builder, Int64Array};
+use arrow::array::{Int32Builder, Int64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
@@ -31,6 +31,7 @@ use datafusion::physical_plan::{
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
+use datafusion_common::cast::as_primitive_array;
use datafusion_common::DataFusionError;
use datafusion_expr::expr::{BinaryExpr, Cast};
use std::ops::Deref;
@@ -215,7 +216,7 @@ async fn assert_provider_row_count(value: i64, expected_count: i64) -> Result<()
.aggregate(vec![], vec![count(col("flag"))])?;
let results = df.collect().await?;
- let result_col: &Int64Array = as_primitive_array(results[0].column(0));
+ let result_col: &Int64Array = as_primitive_array(results[0].column(0))?;
assert_eq!(result_col.value(0), expected_count);
ctx.register_table("data", Arc::new(provider))?;
@@ -225,7 +226,7 @@ async fn assert_provider_row_count(value: i64, expected_count: i64) -> Result<()
.collect()
.await?;
- let sql_result_col: &Int64Array = as_primitive_array(sql_results[0].column(0));
+ let sql_result_col: &Int64Array = as_primitive_array(sql_results[0].column(0))?;
assert_eq!(sql_result_col.value(0), expected_count);
Ok(())
diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs
index 7cf4f343f..18ebf57c0 100644
--- a/datafusion/core/tests/sql/parquet.rs
+++ b/datafusion/core/tests/sql/parquet.rs
@@ -19,7 +19,7 @@ use std::{fs, path::Path};
use ::parquet::arrow::ArrowWriter;
use datafusion::datasource::listing::ListingOptions;
-use datafusion_common::cast::{as_list_array, as_string_array};
+use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array};
use tempfile::TempDir;
use super::*;
@@ -239,11 +239,7 @@ async fn parquet_list_columns() {
let utf8_list_array = as_list_array(batch.column(1)).unwrap();
assert_eq!(
- int_list_array
- .value(0)
- .as_any()
- .downcast_ref::<PrimitiveArray<Int64Type>>()
- .unwrap(),
+ as_primitive_array::<Int64Type>(&int_list_array.value(0)).unwrap(),
&PrimitiveArray::<Int64Type>::from(vec![Some(1), Some(2), Some(3),])
);
@@ -253,22 +249,14 @@ async fn parquet_list_columns() {
);
assert_eq!(
- int_list_array
- .value(1)
- .as_any()
- .downcast_ref::<PrimitiveArray<Int64Type>>()
- .unwrap(),
+ as_primitive_array::<Int64Type>(&int_list_array.value(1)).unwrap(),
&PrimitiveArray::<Int64Type>::from(vec![None, Some(1),])
);
assert!(utf8_list_array.is_null(1));
assert_eq!(
- int_list_array
- .value(2)
- .as_any()
- .downcast_ref::<PrimitiveArray<Int64Type>>()
- .unwrap(),
+ as_primitive_array::<Int64Type>(&int_list_array.value(2)).unwrap(),
&PrimitiveArray::<Int64Type>::from(vec![Some(4),])
);
diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs
index 2903d4272..fd8ddb832 100644
--- a/datafusion/core/tests/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined_aggregates.rs
@@ -22,7 +22,7 @@ use std::sync::Arc;
use datafusion::{
arrow::{
- array::{as_primitive_array, ArrayRef, Float64Array, TimestampNanosecondArray},
+ array::{ArrayRef, Float64Array, TimestampNanosecondArray},
datatypes::{DataType, Field, Float64Type, TimeUnit, TimestampNanosecondType},
record_batch::RecordBatch,
},
@@ -37,6 +37,7 @@ use datafusion::{
prelude::SessionContext,
scalar::ScalarValue,
};
+use datafusion_common::cast::as_primitive_array;
#[tokio::test]
/// Basic query with a udaf returning a structure
@@ -227,8 +228,8 @@ impl Accumulator for FirstSelector {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
// cast arguments to the appropriate type (DataFusion will type
// check these based on the declared allowed input types)
- let v = as_primitive_array::<Float64Type>(&values[0]);
- let t = as_primitive_array::<TimestampNanosecondType>(&values[1]);
+ let v = as_primitive_array::<Float64Type>(&values[0])?;
+ let t = as_primitive_array::<TimestampNanosecondType>(&values[1])?;
// Update the actual values
for (value, time) in v.iter().zip(t.iter()) {
diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs
index deef5dec2..64d6fa7b4 100644
--- a/datafusion/physical-expr/src/aggregate/median.rs
+++ b/datafusion/physical-expr/src/aggregate/median.rs
@@ -19,12 +19,13 @@
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
-use arrow::array::{Array, ArrayRef, PrimitiveArray, PrimitiveBuilder};
+use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
use arrow::compute::sort;
use arrow::datatypes::{
ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type,
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
+use datafusion_common::cast::as_primitive_array;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, AggregateState};
use std::any::Any;
@@ -102,12 +103,7 @@ macro_rules! median {
return Ok(ScalarValue::Null);
}
let sorted = sort(&combined, None)?;
- let array = sorted
- .as_any()
- .downcast_ref::<PrimitiveArray<$TY>>()
- .ok_or(DataFusionError::Internal(
- "median! macro failed to cast array to expected type".to_string(),
- ))?;
+ let array = as_primitive_array::<$TY>(&sorted)?;
let len = sorted.len();
let mid = len / 2;
if len % 2 == 0 {
@@ -209,14 +205,7 @@ fn combine_arrays<T: ArrowPrimitiveType>(arrays: &[ArrayRef]) -> Result<ArrayRef
let len = arrays.iter().map(|a| a.len() - a.null_count()).sum();
let mut builder: PrimitiveBuilder<T> = PrimitiveBuilder::with_capacity(len);
for array in arrays {
- let array = array
- .as_any()
- .downcast_ref::<PrimitiveArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal(
- "combine_arrays failed to cast array to expected type".to_string(),
- )
- })?;
+ let array = as_primitive_array::<T>(array)?;
for i in 0..array.len() {
if !array.is_null(i) {
builder.append_value(array.value(i));
diff --git a/datafusion/physical-expr/src/crypto_expressions.rs b/datafusion/physical-expr/src/crypto_expressions.rs
index 85f3ebdb5..89422399e 100644
--- a/datafusion/physical-expr/src/crypto_expressions.rs
+++ b/datafusion/physical-expr/src/crypto_expressions.rs
@@ -19,13 +19,13 @@
use arrow::{
array::{
- Array, ArrayRef, BinaryArray, GenericBinaryArray, GenericStringArray,
- OffsetSizeTrait, StringArray,
+ Array, ArrayRef, BinaryArray, GenericStringArray, OffsetSizeTrait, StringArray,
},
datatypes::DataType,
};
use blake2::{Blake2b512, Blake2s256, Digest};
use blake3::Hasher as Blake3;
+use datafusion_common::cast::as_generic_binary_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
@@ -134,15 +134,7 @@ impl DigestAlgorithm {
where
T: OffsetSizeTrait,
{
- let input_value = value
- .as_any()
- .downcast_ref::<GenericBinaryArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast value to {}",
- type_name::<GenericBinaryArray<T>>()
- ))
- })?;
+ let input_value = as_generic_binary_array::<T>(value)?;
let array: ArrayRef = match self {
Self::Md5 => digest_to_array!(Md5, input_value),
Self::Sha224 => digest_to_array!(Sha224, input_value),
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index 63fe2292a..10efb8c30 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -33,7 +33,9 @@ use arrow::record_batch::RecordBatch;
use arrow::util::bit_iterator::BitIndexIterator;
use arrow::{downcast_dictionary_array, downcast_primitive_array};
use datafusion_common::{
- cast::{as_boolean_array, as_string_array},
+ cast::{
+ as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array,
+ },
DataFusionError, Result, ScalarValue,
};
use datafusion_expr::ColumnarValue;
@@ -178,11 +180,11 @@ fn make_set(array: &dyn Array) -> Result<Box<dyn Set>> {
Box::new(ArraySet::new(array, make_hash_set(array)))
},
DataType::Decimal128(_, _) => {
- let array = as_primitive_array::<Decimal128Type>(array);
+ let array = as_primitive_array::<Decimal128Type>(array)?;
Box::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Decimal256(_, _) => {
- let array = as_primitive_array::<Decimal256Type>(array);
+ let array = as_primitive_array::<Decimal256Type>(array)?;
Box::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Utf8 => {
@@ -194,11 +196,11 @@ fn make_set(array: &dyn Array) -> Result<Box<dyn Set>> {
Box::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Binary => {
- let array = as_generic_binary_array::<i32>(array);
+ let array = as_generic_binary_array::<i32>(array)?;
Box::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::LargeBinary => {
- let array = as_generic_binary_array::<i64>(array);
+ let array = as_generic_binary_array::<i64>(array)?;
Box::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"),
diff --git a/datafusion/physical-expr/src/hash_utils.rs b/datafusion/physical-expr/src/hash_utils.rs
index d6cde1e7e..c687eb80e 100644
--- a/datafusion/physical-expr/src/hash_utils.rs
+++ b/datafusion/physical-expr/src/hash_utils.rs
@@ -23,7 +23,9 @@ use arrow::datatypes::*;
use arrow::{downcast_dictionary_array, downcast_primitive_array};
use arrow_buffer::i256;
use datafusion_common::{
- cast::{as_boolean_array, as_string_array},
+ cast::{
+ as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array,
+ },
DataFusionError, Result,
};
use std::sync::Arc;
@@ -217,18 +219,18 @@ pub fn create_hashes<'a>(
DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, multi_col),
DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, multi_col),
DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, multi_col),
- DataType::Binary => hash_array(as_generic_binary_array::<i32>(array), random_state, hashes_buffer, multi_col),
- DataType::LargeBinary => hash_array(as_generic_binary_array::<i64>(array), random_state, hashes_buffer, multi_col),
+ DataType::Binary => hash_array(as_generic_binary_array::<i32>(array)?, random_state, hashes_buffer, multi_col),
+ DataType::LargeBinary => hash_array(as_generic_binary_array::<i64>(array)?, random_state, hashes_buffer, multi_col),
DataType::FixedSizeBinary(_) => {
let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap();
hash_array(array, random_state, hashes_buffer, multi_col)
}
DataType::Decimal128(_, _) => {
- let array = as_primitive_array::<Decimal128Type>(array);
+ let array = as_primitive_array::<Decimal128Type>(array)?;
hash_array(array, random_state, hashes_buffer, multi_col)
}
DataType::Decimal256(_, _) => {
- let array = as_primitive_array::<Decimal256Type>(array);
+ let array = as_primitive_array::<Decimal256Type>(array)?;
hash_array(array, random_state, hashes_buffer, multi_col)
}
DataType::Dictionary(_, _) => downcast_dictionary_array! {
diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs
index f06ad6dc2..2199c009a 100644
--- a/datafusion/physical-expr/src/string_expressions.rs
+++ b/datafusion/physical-expr/src/string_expressions.rs
@@ -23,12 +23,15 @@
use arrow::{
array::{
- Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array,
- OffsetSizeTrait, PrimitiveArray, StringArray,
+ Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait,
+ StringArray,
},
datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType},
};
-use datafusion_common::{cast::as_string_array, ScalarValue};
+use datafusion_common::{
+ cast::{as_int64_array, as_primitive_array, as_string_array},
+ ScalarValue,
+};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use std::any::type_name;
@@ -50,43 +53,6 @@ macro_rules! downcast_string_arg {
}};
}
-macro_rules! downcast_primitive_array_arg {
- ($ARG:expr, $NAME:expr, $T:ident) => {{
- $ARG.as_any()
- .downcast_ref::<PrimitiveArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast {} to {}",
- $NAME,
- type_name::<PrimitiveArray<T>>()
- ))
- })?
- }};
-}
-
-macro_rules! downcast_arg {
- ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
- $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast {} to {}",
- $NAME,
- type_name::<$ARRAY_TYPE>()
- ))
- })?
- }};
-}
-
-macro_rules! downcast_vec {
- ($ARGS:expr, $ARRAY_TYPE:ident) => {{
- $ARGS
- .iter()
- .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() {
- Some(array) => Ok(array),
- _ => Err(DataFusionError::Internal("failed to downcast".to_string())),
- })
- }};
-}
-
/// applies a unary expression to `args[0]` that is expected to be downcastable to
/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset)
/// # Errors
@@ -236,7 +202,7 @@ pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character.
/// chr(65) = 'A'
pub fn chr(args: &[ArrayRef]) -> Result<ArrayRef> {
- let integer_array = downcast_arg!(args[0], "integer", Int64Array);
+ let integer_array = as_int64_array(&args[0])?;
// first map is the iterator, second is for the `Option<_>`
let result = integer_array
@@ -329,7 +295,10 @@ pub fn concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
/// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22'
pub fn concat_ws(args: &[ArrayRef]) -> Result<ArrayRef> {
// downcast all arguments to strings
- let args = downcast_vec!(args, StringArray).collect::<Result<Vec<&StringArray>>>()?;
+ let args = args
+ .iter()
+ .map(|e| as_string_array(e))
+ .collect::<Result<Vec<&StringArray>>>()?;
// do not accept 0 or 1 arguments.
if args.len() < 2 {
@@ -442,7 +411,7 @@ pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// repeat('Pg', 4) = 'PgPgPgPg'
pub fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = downcast_string_arg!(args[0], "string", T);
- let number_array = downcast_arg!(args[1], "number", Int64Array);
+ let number_array = as_int64_array(&args[1])?;
let result = string_array
.iter()
@@ -520,8 +489,7 @@ pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn split_part<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = downcast_string_arg!(args[0], "string", T);
let delimiter_array = downcast_string_arg!(args[1], "delimiter", T);
- let n_array = downcast_arg!(args[2], "n", Int64Array);
-
+ let n_array = as_int64_array(&args[2])?;
let result = string_array
.iter()
.zip(delimiter_array.iter())
@@ -571,7 +539,7 @@ pub fn to_hex<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
{
- let integer_array = downcast_primitive_array_arg!(args[0], "integer", T);
+ let integer_array = as_primitive_array::<T>(&args[0])?;
let result = integer_array
.iter()