You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/03 13:14:26 UTC
[arrow-datafusion] branch master updated: Improve Error Handling and Readability for downcasting `StructArray` (#4061)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 761e1671b Improve Error Handling and Readability for downcasting `StructArray` (#4061)
761e1671b is described below
commit 761e1671bd3c4988d21a38bb19e50bdac6cfaa61
Author: Burak <bu...@gmail.com>
AuthorDate: Thu Nov 3 16:14:20 2022 +0300
Improve Error Handling and Readability for downcasting `StructArray` (#4061)
* improve error messages for StructArray
* refactor newly added Date32Array downcasting and correct error string
* beautify code
* changes after code review
* fix formatting
---
benchmarks/src/tpch.rs | 6 +++---
datafusion/common/src/cast.rs | 12 +++++++++++-
datafusion/common/src/scalar.rs | 14 +++-----------
.../physical-expr/src/expressions/get_indexed_field.rs | 5 +++--
4 files changed, 20 insertions(+), 17 deletions(-)
diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs
index de619af3f..bd3b3080f 100644
--- a/benchmarks/src/tpch.rs
+++ b/benchmarks/src/tpch.rs
@@ -16,8 +16,7 @@
// under the License.
use arrow::array::{
- Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array,
- StringArray,
+ Array, ArrayRef, Decimal128Array, Float64Array, Int32Array, Int64Array, StringArray,
};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
@@ -27,6 +26,7 @@ use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
+use datafusion::common::cast::as_date32_array;
use datafusion::common::ScalarValue;
use datafusion::logical_expr::Cast;
use datafusion::prelude::*;
@@ -440,7 +440,7 @@ fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s)
}
DataType::Date32 => {
- let array = column.as_any().downcast_ref::<Date32Array>().unwrap();
+ let array = as_date32_array(column).unwrap();
ScalarValue::Date32(Some(array.value(row_index)))
}
DataType::Utf8 => {
diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index 16a2a7422..2ce0ec224 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -21,7 +21,7 @@
//! kernels in arrow-rs such as `as_boolean_array` do.
use crate::DataFusionError;
-use arrow::array::{Array, Date32Array};
+use arrow::array::{Array, Date32Array, StructArray};
// Downcast ArrayRef to Date32Array
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> {
@@ -32,3 +32,13 @@ pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionErro
))
})
}
+
+// Downcast ArrayRef to StructArray
+pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionError> {
+ array.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "Expected a StructArray, got: {}",
+ array.data_type()
+ ))
+ })
+}
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 503d75c06..7cb90fd64 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -39,6 +39,7 @@ use arrow::{
use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
use ordered_float::OrderedFloat;
+use crate::cast::as_struct_array;
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};
@@ -2008,15 +2009,7 @@ impl ScalarValue {
Self::Dictionary(key_type.clone(), Box::new(value))
}
DataType::Struct(fields) => {
- let array =
- array
- .as_any()
- .downcast_ref::<StructArray>()
- .ok_or_else(|| {
- DataFusionError::Internal(
- "Failed to downcast ArrayRef to StructArray".to_string(),
- )
- })?;
+ let array = as_struct_array(array)?;
let mut field_values: Vec<ScalarValue> = Vec::new();
for col_index in 0..array.num_columns() {
let col_array = array.column(col_index);
@@ -3611,8 +3604,7 @@ mod tests {
// iter_to_array for struct scalars
let array =
ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap();
- let array = array.as_any().downcast_ref::<StructArray>().unwrap();
-
+ let array = as_struct_array(&array).unwrap();
let expected = StructArray::from(vec![
(
field_a.clone(),
diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index ff10c06e2..15232fdec 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -19,7 +19,7 @@
use crate::PhysicalExpr;
use arrow::array::Array;
-use arrow::array::{ListArray, StructArray};
+use arrow::array::ListArray;
use arrow::compute::concat;
use crate::physical_expr::down_cast_any_ref;
@@ -27,6 +27,7 @@ use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
+use datafusion_common::cast::as_struct_array;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
@@ -122,7 +123,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
}
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
- let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
+ let as_struct_array = as_struct_array(&array)?;
match as_struct_array.column_by_name(k) {
None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", k))),
Some(col) => Ok(ColumnarValue::Array(col.clone()))