You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/22 22:22:53 UTC

[arrow-datafusion] branch master updated: Refactor downcasting functions with downcastvalue macro and improve error handling of `ListArray` downcasting (#4313)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new bfce07652 Refactor downcasting functions with downcastvalue macro and improve error handling of `ListArray` downcasting (#4313)
bfce07652 is described below

commit bfce076527e40e51357b3815ed367a28e8b73b3a
Author: Burak <bu...@gmail.com>
AuthorDate: Wed Nov 23 01:22:47 2022 +0300

    Refactor downcasting functions with downcastvalue macro and improve error handling of `ListArray` downcasting (#4313)
    
    * refactor casting with downcastvalue macro and add list array downcasting
    
    * fix clippy
---
 datafusion/common/src/cast.rs                      | 99 +++++-----------------
 datafusion/common/src/scalar.rs                    | 17 ++--
 .../core/src/avro_to_arrow/arrow_array_reader.rs   | 10 +--
 datafusion/core/tests/sql/parquet.rs               | 14 +--
 .../physical-expr/src/aggregate/count_distinct.rs  |  6 +-
 .../src/expressions/get_indexed_field.rs           | 11 +--
 datafusion/physical-expr/src/functions.rs          | 10 +--
 7 files changed, 41 insertions(+), 126 deletions(-)

diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index 940168947..3d5036d7e 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -20,132 +20,71 @@
 //! but provide an error message rather than a panic, as the corresponding
 //! kernels in arrow-rs such as `as_boolean_array` do.
 
-use crate::DataFusionError;
+use crate::{downcast_value, DataFusionError};
 use arrow::array::{
     Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
-    Int32Array, Int64Array, StringArray, StructArray, UInt32Array, UInt64Array,
+    Int32Array, Int64Array, ListArray, StringArray, StructArray, UInt32Array,
+    UInt64Array,
 };
 
 // Downcast ArrayRef to Date32Array
 pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> {
-    array.as_any().downcast_ref::<Date32Array>().ok_or_else(|| {
-        DataFusionError::Internal(format!(
-            "Expected a Date32Array, got: {}",
-            array.data_type()
-        ))
-    })
+    Ok(downcast_value!(array, Date32Array))
 }
 
 // Downcast ArrayRef to StructArray
 pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionError> {
-    array.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
-        DataFusionError::Internal(format!(
-            "Expected a StructArray, got: {}",
-            array.data_type()
-        ))
-    })
+    Ok(downcast_value!(array, StructArray))
 }
 
 // Downcast ArrayRef to Int32Array
 pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array, DataFusionError> {
-    array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
-        DataFusionError::Internal(format!(
-            "Expected a Int32Array, got: {}",
-            array.data_type()
-        ))
-    })
+    Ok(downcast_value!(array, Int32Array))
 }
 
 // Downcast ArrayRef to Int64Array
 pub fn as_int64_array(array: &dyn Array) -> Result<&Int64Array, DataFusionError> {
-    array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
-        DataFusionError::Internal(format!(
-            "Expected a Int64Array, got: {}",
-            array.data_type()
-        ))
-    })
+    Ok(downcast_value!(array, Int64Array))
 }
 
 // Downcast ArrayRef to Decimal128Array
 pub fn as_decimal128_array(
     array: &dyn Array,
 ) -> Result<&Decimal128Array, DataFusionError> {
-    array
-        .as_any()
-        .downcast_ref::<Decimal128Array>()
-        .ok_or_else(|| {
-            DataFusionError::Internal(format!(
-                "Expected a Decimal128Array, got: {}",
-                array.data_type()
-            ))
-        })
+    Ok(downcast_value!(array, Decimal128Array))
 }
 
 // Downcast ArrayRef to Float32Array
 pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array, DataFusionError> {
-    array
-        .as_any()
-        .downcast_ref::<Float32Array>()
-        .ok_or_else(|| {
-            DataFusionError::Internal(format!(
-                "Expected a Float32Array, got: {}",
-                array.data_type()
-            ))
-        })
+    Ok(downcast_value!(array, Float32Array))
 }
 
 // Downcast ArrayRef to Float64Array
 pub fn as_float64_array(array: &dyn Array) -> Result<&Float64Array, DataFusionError> {
-    array
-        .as_any()
-        .downcast_ref::<Float64Array>()
-        .ok_or_else(|| {
-            DataFusionError::Internal(format!(
-                "Expected a Float64Array, got: {}",
-                array.data_type()
-            ))
-        })
+    Ok(downcast_value!(array, Float64Array))
 }
 
 // Downcast ArrayRef to StringArray
 pub fn as_string_array(array: &dyn Array) -> Result<&StringArray, DataFusionError> {
-    array.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
-        DataFusionError::Internal(format!(
-            "Expected a StringArray, got: {}",
-            array.data_type()
-        ))
-    })
+    Ok(downcast_value!(array, StringArray))
 }
 
 // Downcast ArrayRef to UInt32Array
 pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array, DataFusionError> {
-    array.as_any().downcast_ref::<UInt32Array>().ok_or_else(|| {
-        DataFusionError::Internal(format!(
-            "Expected a UInt32Array, got: {}",
-            array.data_type()
-        ))
-    })
+    Ok(downcast_value!(array, UInt32Array))
 }
 
 // Downcast ArrayRef to UInt64Array
 pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array, DataFusionError> {
-    array.as_any().downcast_ref::<UInt64Array>().ok_or_else(|| {
-        DataFusionError::Internal(format!(
-            "Expected a UInt64Array, got: {}",
-            array.data_type()
-        ))
-    })
+    Ok(downcast_value!(array, UInt64Array))
 }
 
 // Downcast ArrayRef to BooleanArray
 pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, DataFusionError> {
-    array
-        .as_any()
-        .downcast_ref::<BooleanArray>()
-        .ok_or_else(|| {
-            DataFusionError::Internal(format!(
-                "Expected a BooleanArray, got: {}",
-                array.data_type()
-            ))
-        })
+    Ok(downcast_value!(array, BooleanArray))
+}
+
+// Downcast ArrayRef to ListArray
+pub fn as_list_array(array: &dyn Array) -> Result<&ListArray, DataFusionError> {
+    Ok(downcast_value!(array, ListArray))
 }
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 44bf278b9..96d1ab672 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -24,7 +24,7 @@ use std::ops::{Add, Sub};
 use std::str::FromStr;
 use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
 
-use crate::cast::{as_decimal128_array, as_struct_array};
+use crate::cast::{as_decimal128_array, as_list_array, as_struct_array};
 use crate::delta::shift_months;
 use crate::error::{DataFusionError, Result};
 use arrow::{
@@ -2001,12 +2001,7 @@ impl ScalarValue {
             DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
             DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
             DataType::List(nested_type) => {
-                let list_array =
-                    array.as_any().downcast_ref::<ListArray>().ok_or_else(|| {
-                        DataFusionError::Internal(
-                            "Failed to downcast ListArray".to_string(),
-                        )
-                    })?;
+                let list_array = as_list_array(array)?;
                 let value = match list_array.is_null(index) {
                     true => None,
                     false => {
@@ -2940,7 +2935,7 @@ mod tests {
             Box::new(Field::new("item", DataType::UInt64, false)),
         )
         .to_array();
-        let list_array = list_array_ref.as_any().downcast_ref::<ListArray>().unwrap();
+        let list_array = as_list_array(&list_array_ref).unwrap();
 
         assert!(list_array.is_null(0));
         assert_eq!(list_array.len(), 1);
@@ -2959,7 +2954,7 @@ mod tests {
         )
         .to_array();
 
-        let list_array = list_array_ref.as_any().downcast_ref::<ListArray>().unwrap();
+        let list_array = as_list_array(&list_array_ref)?;
         assert_eq!(list_array.len(), 1);
         assert_eq!(list_array.values().len(), 3);
 
@@ -3758,7 +3753,7 @@ mod tests {
         let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.get_datatype());
         // iter_to_array for list-of-struct
         let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
-        let array = array.as_any().downcast_ref::<ListArray>().unwrap();
+        let array = as_list_array(&array).unwrap();
 
         // Construct expected array with array builders
         let field_a_builder = StringBuilder::with_capacity(4, 1024);
@@ -3922,7 +3917,7 @@ mod tests {
         );
 
         let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
-        let array = array.as_any().downcast_ref::<ListArray>().unwrap();
+        let array = as_list_array(&array).unwrap();
 
         // Construct expected array with array builders
         let inner_builder = Int32Array::builder(8);
diff --git a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs
index f40c5045a..30c32fc57 100644
--- a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs
+++ b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs
@@ -975,9 +975,9 @@ mod test {
     use crate::arrow::array::Array;
     use crate::arrow::datatypes::{Field, TimeUnit};
     use crate::avro_to_arrow::{Reader, ReaderBuilder};
-    use arrow::array::{ListArray, TimestampMicrosecondArray};
+    use arrow::array::TimestampMicrosecondArray;
     use arrow::datatypes::DataType;
-    use datafusion_common::cast::{as_int32_array, as_int64_array};
+    use datafusion_common::cast::{as_int32_array, as_int64_array, as_list_array};
     use std::fs::File;
 
     fn build_reader(name: &str, batch_size: usize) -> Reader<File> {
@@ -1034,11 +1034,7 @@ mod test {
         let batch = reader.next().unwrap().unwrap();
         assert_eq!(batch.num_columns(), 2);
         assert_eq!(batch.num_rows(), 3);
-        let a_array = batch
-            .column(col_id_index)
-            .as_any()
-            .downcast_ref::<ListArray>()
-            .unwrap();
+        let a_array = as_list_array(batch.column(col_id_index)).unwrap();
         assert_eq!(
             *a_array.data_type(),
             DataType::List(Box::new(Field::new("bigint", DataType::Int64, true)))
diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs
index e7b1fe5b1..7cf4f343f 100644
--- a/datafusion/core/tests/sql/parquet.rs
+++ b/datafusion/core/tests/sql/parquet.rs
@@ -19,7 +19,7 @@ use std::{fs, path::Path};
 
 use ::parquet::arrow::ArrowWriter;
 use datafusion::datasource::listing::ListingOptions;
-use datafusion_common::cast::as_string_array;
+use datafusion_common::cast::{as_list_array, as_string_array};
 use tempfile::TempDir;
 
 use super::*;
@@ -235,16 +235,8 @@ async fn parquet_list_columns() {
     assert_eq!(2, batch.num_columns());
     assert_eq!(schema, batch.schema());
 
-    let int_list_array = batch
-        .column(0)
-        .as_any()
-        .downcast_ref::<ListArray>()
-        .unwrap();
-    let utf8_list_array = batch
-        .column(1)
-        .as_any()
-        .downcast_ref::<ListArray>()
-        .unwrap();
+    let int_list_array = as_list_array(batch.column(0)).unwrap();
+    let utf8_list_array = as_list_array(batch.column(1)).unwrap();
 
     assert_eq!(
         int_list_array
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 943f7b632..d4c0b4406 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -226,11 +226,11 @@ mod tests {
     use crate::aggregate::utils::get_accum_scalar_values;
     use arrow::array::{
         ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
-        Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array,
-        UInt8Array,
+        Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
     };
     use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
     use arrow::datatypes::DataType;
+    use datafusion_common::cast::as_list_array;
 
     macro_rules! state_to_vec {
         ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
@@ -380,7 +380,7 @@ mod tests {
         let agg = DistinctCount::new(
             arrays
                 .iter()
-                .map(|a| a.as_any().downcast_ref::<ListArray>().unwrap())
+                .map(|a| as_list_array(a).unwrap())
                 .map(|a| a.values().data_type().clone())
                 .collect::<Vec<_>>(),
             vec![],
diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index 2b77a0b95..8fbb68d61 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -19,7 +19,6 @@
 
 use crate::PhysicalExpr;
 use arrow::array::Array;
-use arrow::array::ListArray;
 use arrow::compute::concat;
 
 use crate::physical_expr::down_cast_any_ref;
@@ -27,7 +26,7 @@ use arrow::{
     datatypes::{DataType, Schema},
     record_batch::RecordBatch,
 };
-use datafusion_common::cast::as_struct_array;
+use datafusion_common::cast::{as_list_array, as_struct_array};
 use datafusion_common::DataFusionError;
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
@@ -91,8 +90,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
                 Ok(ColumnarValue::Scalar(scalar_null))
             }
             (DataType::List(_), ScalarValue::Int64(Some(i))) => {
-                let as_list_array =
-                    array.as_any().downcast_ref::<ListArray>().unwrap();
+                let as_list_array = as_list_array(&array)?;
 
                 if *i < 1 || as_list_array.is_empty() {
                     let scalar_null: ScalarValue = array.data_type().try_into()?;
@@ -349,10 +347,7 @@ mod tests {
         let get_list_expr =
             Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key));
         let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows());
-        let result = result
-            .as_any()
-            .downcast_ref::<ListArray>()
-            .unwrap_or_else(|| panic!("failed to downcast to ListArray : {:?}", result));
+        let result = as_list_array(&result)?;
         let expected =
             &build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect());
         assert_eq!(expected, result);
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index c84ee24c5..1ed83b89a 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -2847,8 +2847,7 @@ mod tests {
     #[test]
     #[cfg(feature = "regex_expressions")]
     fn test_regexp_match() -> Result<()> {
-        use arrow::array::ListArray;
-        use datafusion_common::cast::as_string_array;
+        use datafusion_common::cast::{as_list_array, as_string_array};
         let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
         let execution_props = ExecutionProps::new();
 
@@ -2873,7 +2872,7 @@ mod tests {
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
 
         // downcast works
-        let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+        let result = as_list_array(&result)?;
         let first_row = result.value(0);
         let first_row = as_string_array(&first_row)?;
 
@@ -2887,8 +2886,7 @@ mod tests {
     #[test]
     #[cfg(feature = "regex_expressions")]
     fn test_regexp_match_all_literals() -> Result<()> {
-        use arrow::array::ListArray;
-        use datafusion_common::cast::as_string_array;
+        use datafusion_common::cast::{as_list_array, as_string_array};
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
         let execution_props = ExecutionProps::new();
 
@@ -2913,7 +2911,7 @@ mod tests {
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
 
         // downcast works
-        let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+        let result = as_list_array(&result)?;
         let first_row = result.value(0);
         let first_row = as_string_array(&first_row)?;