You are viewing a plain text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/04/22 12:33:22 UTC

[arrow-datafusion] branch main updated: fix: null handling of `ScalarValue::Struct` (#6085)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 0336621a7b fix: null handling of `ScalarValue::Struct` (#6085)
0336621a7b is described below

commit 0336621a7b635ba3fa97f42a50a99b4d3aee6f7e
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Sat Apr 22 14:33:16 2023 +0200

    fix: null handling of `ScalarValue::Struct` (#6085)
    
    Fixes #6083.
---
 datafusion/common/src/scalar.rs | 105 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 93 insertions(+), 12 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index f313e662da..e56042b38e 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -31,6 +31,7 @@ use crate::cast::{
 };
 use crate::delta::shift_months;
 use crate::error::{DataFusionError, Result};
+use arrow::compute::nullif;
 use arrow::datatypes::{FieldRef, Fields, SchemaBuilder};
 use arrow::{
     array::*,
@@ -2338,6 +2339,9 @@ impl ScalarValue {
                 let mut columns: Vec<Vec<ScalarValue>> =
                     (0..fields.len()).map(|_| Vec::new()).collect();
 
+                // null mask
+                let mut null_mask_builder = BooleanBuilder::new();
+
                 // Iterate over scalars to populate the column scalars for each row
                 for scalar in scalars {
                     if let ScalarValue::Struct(values, fields) = scalar {
@@ -2347,6 +2351,7 @@ impl ScalarValue {
                                 for (column, value) in columns.iter_mut().zip(values) {
                                     column.push(value.clone());
                                 }
+                                null_mask_builder.append_value(false);
                             }
                             None => {
                                 // Push NULL of the appropriate type for each field
@@ -2356,6 +2361,7 @@ impl ScalarValue {
                                     column
                                         .push(ScalarValue::try_from(field.data_type())?);
                                 }
+                                null_mask_builder.append_value(true);
                             }
                         };
                     } else {
@@ -2374,7 +2380,8 @@ impl ScalarValue {
                     })
                     .collect::<Result<Vec<_>>>()?;
 
-                Arc::new(StructArray::from(field_values))
+                let array = StructArray::from(field_values);
+                nullif(&array, &null_mask_builder.finish())?
             }
             DataType::Dictionary(key_type, value_type) => {
                 // create the values array
@@ -2777,16 +2784,8 @@ impl ScalarValue {
                     Arc::new(StructArray::from(field_values))
                 }
                 None => {
-                    let field_values: Vec<_> = fields
-                        .iter()
-                        .map(|field| {
-                            let none_field = Self::try_from(field.data_type())
-                                .expect("Failed to construct null ScalarValue from Struct field type");
-                            (field.as_ref().clone(), none_field.to_array_of_size(size))
-                        })
-                        .collect();
-
-                    Arc::new(StructArray::from(field_values))
+                    let dt = self.get_datatype();
+                    new_null_array(&dt, size)
                 }
             },
             ScalarValue::Dictionary(key_type, v) => {
@@ -3715,9 +3714,10 @@ mod tests {
     use std::cmp::Ordering;
     use std::sync::Arc;
 
-    use arrow::compute;
     use arrow::compute::kernels;
+    use arrow::compute::{self, concat, is_null};
     use arrow::datatypes::ArrowPrimitiveType;
+    use arrow::util::pretty::pretty_format_columns;
     use arrow_array::ArrowNumericType;
     use rand::Rng;
 
@@ -5584,6 +5584,87 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_struct_nulls() {
+        let fields_b = Fields::from(vec![
+            Field::new("ba", DataType::UInt64, true),
+            Field::new("bb", DataType::UInt64, true),
+        ]);
+        let fields = Fields::from(vec![
+            Field::new("a", DataType::UInt64, true),
+            Field::new("b", DataType::Struct(fields_b.clone()), true),
+        ]);
+        let scalars = vec![
+            ScalarValue::Struct(None, fields.clone()),
+            ScalarValue::Struct(
+                Some(vec![
+                    ScalarValue::UInt64(None),
+                    ScalarValue::Struct(None, fields_b.clone()),
+                ]),
+                fields.clone(),
+            ),
+            ScalarValue::Struct(
+                Some(vec![
+                    ScalarValue::UInt64(None),
+                    ScalarValue::Struct(
+                        Some(vec![ScalarValue::UInt64(None), ScalarValue::UInt64(None)]),
+                        fields_b.clone(),
+                    ),
+                ]),
+                fields.clone(),
+            ),
+            ScalarValue::Struct(
+                Some(vec![
+                    ScalarValue::UInt64(Some(1)),
+                    ScalarValue::Struct(
+                        Some(vec![
+                            ScalarValue::UInt64(Some(2)),
+                            ScalarValue::UInt64(Some(3)),
+                        ]),
+                        fields_b,
+                    ),
+                ]),
+                fields,
+            ),
+        ];
+
+        let check_array = |array| {
+            let is_null = is_null(&array).unwrap();
+            assert_eq!(is_null, BooleanArray::from(vec![true, false, false, false]));
+
+            let formatted = pretty_format_columns("col", &[array]).unwrap().to_string();
+            let formatted = formatted.split('\n').collect::<Vec<_>>();
+            let expected = vec![
+                "+---------------------------+",
+                "| col                       |",
+                "+---------------------------+",
+                "|                           |",
+                "| {a: , b: }                |",
+                "| {a: , b: {ba: , bb: }}    |",
+                "| {a: 1, b: {ba: 2, bb: 3}} |",
+                "+---------------------------+",
+            ];
+            assert_eq!(
+                formatted, expected,
+                "Actual:\n{:#?}\n\nExpected:\n{:#?}",
+                formatted, expected
+            );
+        };
+
+        // test `ScalarValue::iter_to_array`
+        let array = ScalarValue::iter_to_array(scalars.clone()).unwrap();
+        check_array(array);
+
+        // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size`
+        let arrays = scalars
+            .iter()
+            .map(ScalarValue::to_array)
+            .collect::<Vec<_>>();
+        let arrays = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
+        let array = concat(&arrays).unwrap();
+        check_array(array);
+    }
+
     fn get_timestamp_test_data(
         sign: i32,
     ) -> Vec<(ScalarValue, ScalarValue, ScalarValue)> {