You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/03 13:14:26 UTC

[arrow-datafusion] branch master updated: Improve Error Handling and Readability for downcasting `StructArray` (#4061)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 761e1671b Improve Error Handling and Readability for downcasting `StructArray` (#4061)
761e1671b is described below

commit 761e1671bd3c4988d21a38bb19e50bdac6cfaa61
Author: Burak <bu...@gmail.com>
AuthorDate: Thu Nov 3 16:14:20 2022 +0300

    Improve Error Handling and Readability for downcasting `StructArray` (#4061)
    
    * improve error messages for StructArray
    
    * refactor newly added Date32Array downcasting and correct error string
    
    * beautify code
    
    * changes after code review
    
    * fix formatting
---
 benchmarks/src/tpch.rs                                     |  6 +++---
 datafusion/common/src/cast.rs                              | 12 +++++++++++-
 datafusion/common/src/scalar.rs                            | 14 +++-----------
 .../physical-expr/src/expressions/get_indexed_field.rs     |  5 +++--
 4 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs
index de619af3f..bd3b3080f 100644
--- a/benchmarks/src/tpch.rs
+++ b/benchmarks/src/tpch.rs
@@ -16,8 +16,7 @@
 // under the License.
 
 use arrow::array::{
-    Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array,
-    StringArray,
+    Array, ArrayRef, Decimal128Array, Float64Array, Int32Array, Int64Array, StringArray,
 };
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
@@ -27,6 +26,7 @@ use std::path::Path;
 use std::sync::Arc;
 use std::time::Instant;
 
+use datafusion::common::cast::as_date32_array;
 use datafusion::common::ScalarValue;
 use datafusion::logical_expr::Cast;
 use datafusion::prelude::*;
@@ -440,7 +440,7 @@ fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
             ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s)
         }
         DataType::Date32 => {
-            let array = column.as_any().downcast_ref::<Date32Array>().unwrap();
+            let array = as_date32_array(column).unwrap();
             ScalarValue::Date32(Some(array.value(row_index)))
         }
         DataType::Utf8 => {
diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index 16a2a7422..2ce0ec224 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -21,7 +21,7 @@
 //! kernels in arrow-rs such as `as_boolean_array` do.
 
 use crate::DataFusionError;
-use arrow::array::{Array, Date32Array};
+use arrow::array::{Array, Date32Array, StructArray};
 
 // Downcast ArrayRef to Date32Array
 pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> {
@@ -32,3 +32,13 @@ pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionErro
         ))
     })
 }
+
+// Downcast ArrayRef to StructArray
+pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionError> {
+    array.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
+        DataFusionError::Internal(format!(
+            "Expected a StructArray, got: {}",
+            array.data_type()
+        ))
+    })
+}
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 503d75c06..7cb90fd64 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -39,6 +39,7 @@ use arrow::{
 use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
 use ordered_float::OrderedFloat;
 
+use crate::cast::as_struct_array;
 use crate::delta::shift_months;
 use crate::error::{DataFusionError, Result};
 
@@ -2008,15 +2009,7 @@ impl ScalarValue {
                 Self::Dictionary(key_type.clone(), Box::new(value))
             }
             DataType::Struct(fields) => {
-                let array =
-                    array
-                        .as_any()
-                        .downcast_ref::<StructArray>()
-                        .ok_or_else(|| {
-                            DataFusionError::Internal(
-                                "Failed to downcast ArrayRef to StructArray".to_string(),
-                            )
-                        })?;
+                let array = as_struct_array(array)?;
                 let mut field_values: Vec<ScalarValue> = Vec::new();
                 for col_index in 0..array.num_columns() {
                     let col_array = array.column(col_index);
@@ -3611,8 +3604,7 @@ mod tests {
         // iter_to_array for struct scalars
         let array =
             ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap();
-        let array = array.as_any().downcast_ref::<StructArray>().unwrap();
-
+        let array = as_struct_array(&array).unwrap();
         let expected = StructArray::from(vec![
             (
                 field_a.clone(),
diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index ff10c06e2..15232fdec 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -19,7 +19,7 @@
 
 use crate::PhysicalExpr;
 use arrow::array::Array;
-use arrow::array::{ListArray, StructArray};
+use arrow::array::ListArray;
 use arrow::compute::concat;
 
 use crate::physical_expr::down_cast_any_ref;
@@ -27,6 +27,7 @@ use arrow::{
     datatypes::{DataType, Schema},
     record_batch::RecordBatch,
 };
+use datafusion_common::cast::as_struct_array;
 use datafusion_common::DataFusionError;
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
@@ -122,7 +123,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
                 }
             }
             (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
-                let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
+                let as_struct_array = as_struct_array(&array)?;
                 match as_struct_array.column_by_name(k) {
                     None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", k))),
                     Some(col) => Ok(ColumnarValue::Array(col.clone()))