You are viewing a plain-text version of this content. The link to the canonical version was lost in the plain-text conversion.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/22 22:22:53 UTC
[arrow-datafusion] branch master updated: Refactor downcasting functions with the `downcast_value!` macro and improve error handling of `ListArray` downcasting (#4313)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new bfce07652 Refactor downcasting functions with the `downcast_value!` macro and improve error handling of `ListArray` downcasting (#4313)
bfce07652 is described below
commit bfce076527e40e51357b3815ed367a28e8b73b3a
Author: Burak <bu...@gmail.com>
AuthorDate: Wed Nov 23 01:22:47 2022 +0300
Refactor downcasting functions with the `downcast_value!` macro and improve error handling of `ListArray` downcasting (#4313)
* refactor casting with the `downcast_value!` macro and add `ListArray` downcasting (`as_list_array`)
* fix clippy warnings
---
datafusion/common/src/cast.rs | 99 +++++-----------------
datafusion/common/src/scalar.rs | 17 ++--
.../core/src/avro_to_arrow/arrow_array_reader.rs | 10 +--
datafusion/core/tests/sql/parquet.rs | 14 +--
.../physical-expr/src/aggregate/count_distinct.rs | 6 +-
.../src/expressions/get_indexed_field.rs | 11 +--
datafusion/physical-expr/src/functions.rs | 10 +--
7 files changed, 41 insertions(+), 126 deletions(-)
diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index 940168947..3d5036d7e 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -20,132 +20,71 @@
//! but provide an error message rather than a panic, as the corresponding
//! kernels in arrow-rs such as `as_boolean_array` do.
-use crate::DataFusionError;
+use crate::{downcast_value, DataFusionError};
use arrow::array::{
Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
- Int32Array, Int64Array, StringArray, StructArray, UInt32Array, UInt64Array,
+ Int32Array, Int64Array, ListArray, StringArray, StructArray, UInt32Array,
+ UInt64Array,
};
// Downcast ArrayRef to Date32Array
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> {
- array.as_any().downcast_ref::<Date32Array>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a Date32Array, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, Date32Array))
}
// Downcast ArrayRef to StructArray
pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionError> {
- array.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a StructArray, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, StructArray))
}
// Downcast ArrayRef to Int32Array
pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array, DataFusionError> {
- array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a Int32Array, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, Int32Array))
}
// Downcast ArrayRef to Int64Array
pub fn as_int64_array(array: &dyn Array) -> Result<&Int64Array, DataFusionError> {
- array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a Int64Array, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, Int64Array))
}
// Downcast ArrayRef to Decimal128Array
pub fn as_decimal128_array(
array: &dyn Array,
) -> Result<&Decimal128Array, DataFusionError> {
- array
- .as_any()
- .downcast_ref::<Decimal128Array>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a Decimal128Array, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, Decimal128Array))
}
// Downcast ArrayRef to Float32Array
pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array, DataFusionError> {
- array
- .as_any()
- .downcast_ref::<Float32Array>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a Float32Array, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, Float32Array))
}
// Downcast ArrayRef to Float64Array
pub fn as_float64_array(array: &dyn Array) -> Result<&Float64Array, DataFusionError> {
- array
- .as_any()
- .downcast_ref::<Float64Array>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a Float64Array, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, Float64Array))
}
// Downcast ArrayRef to StringArray
pub fn as_string_array(array: &dyn Array) -> Result<&StringArray, DataFusionError> {
- array.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a StringArray, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, StringArray))
}
// Downcast ArrayRef to UInt32Array
pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array, DataFusionError> {
- array.as_any().downcast_ref::<UInt32Array>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a UInt32Array, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, UInt32Array))
}
// Downcast ArrayRef to UInt64Array
pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array, DataFusionError> {
- array.as_any().downcast_ref::<UInt64Array>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a UInt64Array, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, UInt64Array))
}
// Downcast ArrayRef to BooleanArray
pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, DataFusionError> {
- array
- .as_any()
- .downcast_ref::<BooleanArray>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "Expected a BooleanArray, got: {}",
- array.data_type()
- ))
- })
+ Ok(downcast_value!(array, BooleanArray))
+}
+
+// Downcast ArrayRef to ListArray
+pub fn as_list_array(array: &dyn Array) -> Result<&ListArray, DataFusionError> {
+ Ok(downcast_value!(array, ListArray))
}
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 44bf278b9..96d1ab672 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -24,7 +24,7 @@ use std::ops::{Add, Sub};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
-use crate::cast::{as_decimal128_array, as_struct_array};
+use crate::cast::{as_decimal128_array, as_list_array, as_struct_array};
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};
use arrow::{
@@ -2001,12 +2001,7 @@ impl ScalarValue {
DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
DataType::List(nested_type) => {
- let list_array =
- array.as_any().downcast_ref::<ListArray>().ok_or_else(|| {
- DataFusionError::Internal(
- "Failed to downcast ListArray".to_string(),
- )
- })?;
+ let list_array = as_list_array(array)?;
let value = match list_array.is_null(index) {
true => None,
false => {
@@ -2940,7 +2935,7 @@ mod tests {
Box::new(Field::new("item", DataType::UInt64, false)),
)
.to_array();
- let list_array = list_array_ref.as_any().downcast_ref::<ListArray>().unwrap();
+ let list_array = as_list_array(&list_array_ref).unwrap();
assert!(list_array.is_null(0));
assert_eq!(list_array.len(), 1);
@@ -2959,7 +2954,7 @@ mod tests {
)
.to_array();
- let list_array = list_array_ref.as_any().downcast_ref::<ListArray>().unwrap();
+ let list_array = as_list_array(&list_array_ref)?;
assert_eq!(list_array.len(), 1);
assert_eq!(list_array.values().len(), 3);
@@ -3758,7 +3753,7 @@ mod tests {
let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.get_datatype());
// iter_to_array for list-of-struct
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
- let array = array.as_any().downcast_ref::<ListArray>().unwrap();
+ let array = as_list_array(&array).unwrap();
// Construct expected array with array builders
let field_a_builder = StringBuilder::with_capacity(4, 1024);
@@ -3922,7 +3917,7 @@ mod tests {
);
let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
- let array = array.as_any().downcast_ref::<ListArray>().unwrap();
+ let array = as_list_array(&array).unwrap();
// Construct expected array with array builders
let inner_builder = Int32Array::builder(8);
diff --git a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs
index f40c5045a..30c32fc57 100644
--- a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs
+++ b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs
@@ -975,9 +975,9 @@ mod test {
use crate::arrow::array::Array;
use crate::arrow::datatypes::{Field, TimeUnit};
use crate::avro_to_arrow::{Reader, ReaderBuilder};
- use arrow::array::{ListArray, TimestampMicrosecondArray};
+ use arrow::array::TimestampMicrosecondArray;
use arrow::datatypes::DataType;
- use datafusion_common::cast::{as_int32_array, as_int64_array};
+ use datafusion_common::cast::{as_int32_array, as_int64_array, as_list_array};
use std::fs::File;
fn build_reader(name: &str, batch_size: usize) -> Reader<File> {
@@ -1034,11 +1034,7 @@ mod test {
let batch = reader.next().unwrap().unwrap();
assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.num_rows(), 3);
- let a_array = batch
- .column(col_id_index)
- .as_any()
- .downcast_ref::<ListArray>()
- .unwrap();
+ let a_array = as_list_array(batch.column(col_id_index)).unwrap();
assert_eq!(
*a_array.data_type(),
DataType::List(Box::new(Field::new("bigint", DataType::Int64, true)))
diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs
index e7b1fe5b1..7cf4f343f 100644
--- a/datafusion/core/tests/sql/parquet.rs
+++ b/datafusion/core/tests/sql/parquet.rs
@@ -19,7 +19,7 @@ use std::{fs, path::Path};
use ::parquet::arrow::ArrowWriter;
use datafusion::datasource::listing::ListingOptions;
-use datafusion_common::cast::as_string_array;
+use datafusion_common::cast::{as_list_array, as_string_array};
use tempfile::TempDir;
use super::*;
@@ -235,16 +235,8 @@ async fn parquet_list_columns() {
assert_eq!(2, batch.num_columns());
assert_eq!(schema, batch.schema());
- let int_list_array = batch
- .column(0)
- .as_any()
- .downcast_ref::<ListArray>()
- .unwrap();
- let utf8_list_array = batch
- .column(1)
- .as_any()
- .downcast_ref::<ListArray>()
- .unwrap();
+ let int_list_array = as_list_array(batch.column(0)).unwrap();
+ let utf8_list_array = as_list_array(batch.column(1)).unwrap();
assert_eq!(
int_list_array
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 943f7b632..d4c0b4406 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -226,11 +226,11 @@ mod tests {
use crate::aggregate::utils::get_accum_scalar_values;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
- Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array,
- UInt8Array,
+ Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
use arrow::datatypes::DataType;
+ use datafusion_common::cast::as_list_array;
macro_rules! state_to_vec {
($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
@@ -380,7 +380,7 @@ mod tests {
let agg = DistinctCount::new(
arrays
.iter()
- .map(|a| a.as_any().downcast_ref::<ListArray>().unwrap())
+ .map(|a| as_list_array(a).unwrap())
.map(|a| a.values().data_type().clone())
.collect::<Vec<_>>(),
vec![],
diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index 2b77a0b95..8fbb68d61 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -19,7 +19,6 @@
use crate::PhysicalExpr;
use arrow::array::Array;
-use arrow::array::ListArray;
use arrow::compute::concat;
use crate::physical_expr::down_cast_any_ref;
@@ -27,7 +26,7 @@ use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
-use datafusion_common::cast::as_struct_array;
+use datafusion_common::cast::{as_list_array, as_struct_array};
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
@@ -91,8 +90,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
Ok(ColumnarValue::Scalar(scalar_null))
}
(DataType::List(_), ScalarValue::Int64(Some(i))) => {
- let as_list_array =
- array.as_any().downcast_ref::<ListArray>().unwrap();
+ let as_list_array = as_list_array(&array)?;
if *i < 1 || as_list_array.is_empty() {
let scalar_null: ScalarValue = array.data_type().try_into()?;
@@ -349,10 +347,7 @@ mod tests {
let get_list_expr =
Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key));
let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows());
- let result = result
- .as_any()
- .downcast_ref::<ListArray>()
- .unwrap_or_else(|| panic!("failed to downcast to ListArray : {:?}", result));
+ let result = as_list_array(&result)?;
let expected =
&build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect());
assert_eq!(expected, result);
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index c84ee24c5..1ed83b89a 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -2847,8 +2847,7 @@ mod tests {
#[test]
#[cfg(feature = "regex_expressions")]
fn test_regexp_match() -> Result<()> {
- use arrow::array::ListArray;
- use datafusion_common::cast::as_string_array;
+ use datafusion_common::cast::{as_list_array, as_string_array};
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
let execution_props = ExecutionProps::new();
@@ -2873,7 +2872,7 @@ mod tests {
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
// downcast works
- let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+ let result = as_list_array(&result)?;
let first_row = result.value(0);
let first_row = as_string_array(&first_row)?;
@@ -2887,8 +2886,7 @@ mod tests {
#[test]
#[cfg(feature = "regex_expressions")]
fn test_regexp_match_all_literals() -> Result<()> {
- use arrow::array::ListArray;
- use datafusion_common::cast::as_string_array;
+ use datafusion_common::cast::{as_list_array, as_string_array};
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let execution_props = ExecutionProps::new();
@@ -2913,7 +2911,7 @@ mod tests {
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
// downcast works
- let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+ let result = as_list_array(&result)?;
let first_row = result.value(0);
let first_row = as_string_array(&first_row)?;