You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/04/22 12:33:22 UTC
[arrow-datafusion] branch main updated: fix: null handling of `ScalarValue::Struct` (#6085)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0336621a7b fix: null handling of `ScalarValue::Struct` (#6085)
0336621a7b is described below
commit 0336621a7b635ba3fa97f42a50a99b4d3aee6f7e
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Sat Apr 22 14:33:16 2023 +0200
fix: null handling of `ScalarValue::Struct` (#6085)
Fixes #6083.
---
datafusion/common/src/scalar.rs | 105 +++++++++++++++++++++++++++++++++++-----
1 file changed, 93 insertions(+), 12 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index f313e662da..e56042b38e 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -31,6 +31,7 @@ use crate::cast::{
};
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};
+use arrow::compute::nullif;
use arrow::datatypes::{FieldRef, Fields, SchemaBuilder};
use arrow::{
array::*,
@@ -2338,6 +2339,9 @@ impl ScalarValue {
let mut columns: Vec<Vec<ScalarValue>> =
(0..fields.len()).map(|_| Vec::new()).collect();
+ // null mask
+ let mut null_mask_builder = BooleanBuilder::new();
+
// Iterate over scalars to populate the column scalars for each row
for scalar in scalars {
if let ScalarValue::Struct(values, fields) = scalar {
@@ -2347,6 +2351,7 @@ impl ScalarValue {
for (column, value) in columns.iter_mut().zip(values) {
column.push(value.clone());
}
+ null_mask_builder.append_value(false);
}
None => {
// Push NULL of the appropriate type for each field
@@ -2356,6 +2361,7 @@ impl ScalarValue {
column
.push(ScalarValue::try_from(field.data_type())?);
}
+ null_mask_builder.append_value(true);
}
};
} else {
@@ -2374,7 +2380,8 @@ impl ScalarValue {
})
.collect::<Result<Vec<_>>>()?;
- Arc::new(StructArray::from(field_values))
+ let array = StructArray::from(field_values);
+ nullif(&array, &null_mask_builder.finish())?
}
DataType::Dictionary(key_type, value_type) => {
// create the values array
@@ -2777,16 +2784,8 @@ impl ScalarValue {
Arc::new(StructArray::from(field_values))
}
None => {
- let field_values: Vec<_> = fields
- .iter()
- .map(|field| {
- let none_field = Self::try_from(field.data_type())
- .expect("Failed to construct null ScalarValue from Struct field type");
- (field.as_ref().clone(), none_field.to_array_of_size(size))
- })
- .collect();
-
- Arc::new(StructArray::from(field_values))
+ let dt = self.get_datatype();
+ new_null_array(&dt, size)
}
},
ScalarValue::Dictionary(key_type, v) => {
@@ -3715,9 +3714,10 @@ mod tests {
use std::cmp::Ordering;
use std::sync::Arc;
- use arrow::compute;
use arrow::compute::kernels;
+ use arrow::compute::{self, concat, is_null};
use arrow::datatypes::ArrowPrimitiveType;
+ use arrow::util::pretty::pretty_format_columns;
use arrow_array::ArrowNumericType;
use rand::Rng;
@@ -5584,6 +5584,87 @@ mod tests {
}
}
+ #[test]
+ fn test_struct_nulls() {
+ let fields_b = Fields::from(vec![
+ Field::new("ba", DataType::UInt64, true),
+ Field::new("bb", DataType::UInt64, true),
+ ]);
+ let fields = Fields::from(vec![
+ Field::new("a", DataType::UInt64, true),
+ Field::new("b", DataType::Struct(fields_b.clone()), true),
+ ]);
+ let scalars = vec![
+ ScalarValue::Struct(None, fields.clone()),
+ ScalarValue::Struct(
+ Some(vec![
+ ScalarValue::UInt64(None),
+ ScalarValue::Struct(None, fields_b.clone()),
+ ]),
+ fields.clone(),
+ ),
+ ScalarValue::Struct(
+ Some(vec![
+ ScalarValue::UInt64(None),
+ ScalarValue::Struct(
+ Some(vec![ScalarValue::UInt64(None), ScalarValue::UInt64(None)]),
+ fields_b.clone(),
+ ),
+ ]),
+ fields.clone(),
+ ),
+ ScalarValue::Struct(
+ Some(vec![
+ ScalarValue::UInt64(Some(1)),
+ ScalarValue::Struct(
+ Some(vec![
+ ScalarValue::UInt64(Some(2)),
+ ScalarValue::UInt64(Some(3)),
+ ]),
+ fields_b,
+ ),
+ ]),
+ fields,
+ ),
+ ];
+
+ let check_array = |array| {
+ let is_null = is_null(&array).unwrap();
+ assert_eq!(is_null, BooleanArray::from(vec![true, false, false, false]));
+
+ let formatted = pretty_format_columns("col", &[array]).unwrap().to_string();
+ let formatted = formatted.split('\n').collect::<Vec<_>>();
+ let expected = vec![
+ "+---------------------------+",
+ "| col                       |",
+ "+---------------------------+",
+ "|                           |",
+ "| {a: , b: }                |",
+ "| {a: , b: {ba: , bb: }}    |",
+ "| {a: 1, b: {ba: 2, bb: 3}} |",
+ "+---------------------------+",
+ ];
+ assert_eq!(
+ formatted, expected,
+ "Actual:\n{:#?}\n\nExpected:\n{:#?}",
+ formatted, expected
+ );
+ };
+
+ // test `ScalarValue::iter_to_array`
+ let array = ScalarValue::iter_to_array(scalars.clone()).unwrap();
+ check_array(array);
+
+ // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size`
+ let arrays = scalars
+ .iter()
+ .map(ScalarValue::to_array)
+ .collect::<Vec<_>>();
+ let arrays = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
+ let array = concat(&arrays).unwrap();
+ check_array(array);
+ }
+
fn get_timestamp_test_data(
sign: i32,
) -> Vec<(ScalarValue, ScalarValue, ScalarValue)> {