You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/29 10:47:12 UTC

[arrow-rs] branch master updated: Support StructArray in Row Format (#3159) (#3212)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 733d32e90 Support StructArray in Row Format (#3159) (#3212)
733d32e90 is described below

commit 733d32e90b67bbc62bcff6fc4aa1873d43d4e686
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue Nov 29 10:47:07 2022 +0000

    Support StructArray in Row Format (#3159) (#3212)
    
    * Extract Codec and Encoder
    
    * Add StructArray support to Row format (#3159)
    
    * More docs
    
    * Review feedback
---
 arrow/src/row/fixed.rs |   6 +-
 arrow/src/row/mod.rs   | 515 ++++++++++++++++++++++++++++++++++++-------------
 2 files changed, 384 insertions(+), 137 deletions(-)

diff --git a/arrow/src/row/fixed.rs b/arrow/src/row/fixed.rs
index 0bad033d9..9aef83ce2 100644
--- a/arrow/src/row/fixed.rs
+++ b/arrow/src/row/fixed.rs
@@ -267,7 +267,11 @@ pub fn decode_bool(rows: &mut [&[u8]], options: SortOptions) -> BooleanArray {
     unsafe { BooleanArray::from(builder.build_unchecked()) }
 }
 
-fn decode_nulls(rows: &[&[u8]]) -> (usize, Buffer) {
+/// Decodes a single byte from each row, interpreting `0x01` as a valid value
+/// and all other values as a null
+///
+/// Returns the null count and null buffer
+pub fn decode_nulls(rows: &[&[u8]]) -> (usize, Buffer) {
     let mut null_count = 0;
     let buffer = MutableBuffer::collect_bool(rows.len(), |idx| {
         let valid = rows[idx][0] == 1;
diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs
index 21d8e4df0..8572bf892 100644
--- a/arrow/src/row/mod.rs
+++ b/arrow/src/row/mod.rs
@@ -131,6 +131,7 @@ use std::sync::Arc;
 
 use arrow_array::cast::*;
 use arrow_array::*;
+use arrow_data::ArrayDataBuilder;
 
 use crate::compute::SortOptions;
 use crate::datatypes::*;
@@ -307,6 +308,31 @@ mod variable;
 ///
 ///      Input                  Row Format
 /// ```
+///
+/// ## Struct Encoding
+///
+/// A null is encoded as a `0_u8`.
+///
+/// A valid value is encoded as `1_u8` followed by the row encoding of each child.
+///
+/// This encoding effectively flattens the schema in a depth-first fashion.
+///
+/// For example
+///
+/// ```text
+/// ┌───────┬────────────────────────┬───────┐
+/// │ Int32 │ Struct[Int32, Float32] │ Int32 │
+/// └───────┴────────────────────────┴───────┘
+/// ```
+///
+/// Is encoded as
+///
+/// ```text
+/// ┌───────┬───────────────┬───────┬─────────┬───────┐
+/// │ Int32 │ Null Sentinel │ Int32 │ Float32 │ Int32 │
+/// └───────┴───────────────┴───────┴─────────┴───────┘
+/// ```
+///
 /// # Ordering
 ///
 /// ## Float Ordering
@@ -332,8 +358,103 @@ mod variable;
 #[derive(Debug)]
 pub struct RowConverter {
     fields: Arc<[SortField]>,
-    /// interning state for column `i`, if column`i` is a dictionary
-    interners: Vec<Option<Box<OrderPreservingInterner>>>,
+    /// State for codecs
+    codecs: Vec<Codec>,
+}
+
+#[derive(Debug)]
+enum Codec {
+    /// No additional codec state is necessary
+    Stateless,
+    /// The interner used to encode dictionary values
+    Dictionary(OrderPreservingInterner),
+    /// A row converter for the child fields
+    /// and the encoding of a row containing only nulls
+    Struct(RowConverter, OwnedRow),
+}
+
+impl Codec {
+    fn new(sort_field: &SortField) -> Result<Self> {
+        match &sort_field.data_type {
+            DataType::Dictionary(_, _) => Ok(Self::Dictionary(Default::default())),
+            d if !d.is_nested() => Ok(Self::Stateless),
+            DataType::Struct(f) => {
+                let sort_fields = f
+                    .iter()
+                    .map(|x| {
+                        SortField::new_with_options(
+                            x.data_type().clone(),
+                            sort_field.options,
+                        )
+                    })
+                    .collect();
+
+                let mut converter = RowConverter::new(sort_fields)?;
+                let nulls: Vec<_> =
+                    f.iter().map(|x| new_null_array(x.data_type(), 1)).collect();
+
+                let nulls = converter.convert_columns(&nulls)?;
+                let owned = OwnedRow {
+                    data: nulls.buffer,
+                    config: nulls.config,
+                };
+
+                Ok(Self::Struct(converter, owned))
+            }
+            _ => Err(ArrowError::NotYetImplemented(format!(
+                "not yet implemented: {:?}",
+                sort_field.data_type
+            ))),
+        }
+    }
+
+    fn encoder(&mut self, array: &dyn Array) -> Result<Encoder<'_>> {
+        match self {
+            Codec::Stateless => Ok(Encoder::Stateless),
+            Codec::Dictionary(interner) => {
+                let values = downcast_dictionary_array! {
+                    array => array.values(),
+                    _ => unreachable!()
+                };
+
+                let mapping = compute_dictionary_mapping(interner, values)
+                    .into_iter()
+                    .map(|maybe_interned| {
+                        maybe_interned.map(|interned| interner.normalized_key(interned))
+                    })
+                    .collect();
+
+                Ok(Encoder::Dictionary(mapping))
+            }
+            Codec::Struct(converter, null) => {
+                let v = as_struct_array(array);
+                let rows = converter.convert_columns(v.columns())?;
+                Ok(Encoder::Struct(rows, null.row()))
+            }
+        }
+    }
+
+    fn size(&self) -> usize {
+        match self {
+            Codec::Stateless => 0,
+            Codec::Dictionary(interner) => interner.size(),
+            Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(),
+        }
+    }
+}
+
+#[derive(Debug)]
+enum Encoder<'a> {
+    /// No additional encoder state is necessary
+    Stateless,
+    /// The mapping from dictionary keys to normalized keys
+    Dictionary(Vec<Option<&'a [u8]>>),
+    /// The row encoding of the child array and the encoding of a null row
+    ///
+    /// It is necessary to encode to a temporary [`Rows`] to avoid serializing
+    /// values that are masked by a null in the parent StructArray, otherwise
+    /// this would establish an ordering between semantically null values
+    Struct(Rows, Row<'a>),
 }
 
 /// Configure the data type and sort order for a given column
@@ -370,21 +491,31 @@ impl RowConverter {
     pub fn new(fields: Vec<SortField>) -> Result<Self> {
         if !Self::supports_fields(&fields) {
             return Err(ArrowError::NotYetImplemented(format!(
-                "not yet implemented: {:?}",
+                "Row format support not yet implemented for: {:?}",
                 fields
             )));
         }
 
-        let interners = (0..fields.len()).map(|_| None).collect();
+        let codecs = fields.iter().map(Codec::new).collect::<Result<_>>()?;
         Ok(Self {
             fields: fields.into(),
-            interners,
+            codecs,
         })
     }
 
     /// Check if the given fields are supported by the row format.
     pub fn supports_fields(fields: &[SortField]) -> bool {
-        fields.iter().all(|x| !DataType::is_nested(&x.data_type))
+        fields.iter().all(|x| Self::supports_datatype(&x.data_type))
+    }
+
+    fn supports_datatype(d: &DataType) -> bool {
+        match d {
+            _ if !d.is_nested() => true,
+            DataType::Struct(f) => {
+                f.iter().all(|x| Self::supports_datatype(x.data_type()))
+            }
+            _ => false,
+        }
     }
 
     /// Convert [`ArrayRef`] columns into [`Rows`]
@@ -403,11 +534,11 @@ impl RowConverter {
             )));
         }
 
-        let dictionaries = columns
+        let encoders = columns
             .iter()
-            .zip(&mut self.interners)
+            .zip(&mut self.codecs)
             .zip(self.fields.iter())
-            .map(|((column, interner), field)| {
+            .map(|((column, codec), field)| {
                 if !column.data_type().equals_datatype(&field.data_type) {
                     return Err(ArrowError::InvalidArgumentError(format!(
                         "RowConverter column schema mismatch, expected {} got {}",
@@ -415,22 +546,7 @@ impl RowConverter {
                         column.data_type()
                     )));
                 }
-
-                let values = downcast_dictionary_array! {
-                    column => column.values(),
-                    _ => return Ok(None)
-                };
-
-                let interner = interner.get_or_insert_with(Default::default);
-
-                let mapping: Vec<_> = compute_dictionary_mapping(interner, values)
-                    .into_iter()
-                    .map(|maybe_interned| {
-                        maybe_interned.map(|interned| interner.normalized_key(interned))
-                    })
-                    .collect();
-
-                Ok(Some(mapping))
+                codec.encoder(column.as_ref())
             })
             .collect::<Result<Vec<_>>>()?;
 
@@ -439,13 +555,13 @@ impl RowConverter {
             // Don't need to validate UTF-8 as came from arrow array
             validate_utf8: false,
         };
-        let mut rows = new_empty_rows(columns, &dictionaries, config);
+        let mut rows = new_empty_rows(columns, &encoders, config);
 
-        for ((column, field), dictionary) in
-            columns.iter().zip(self.fields.iter()).zip(dictionaries)
+        for ((column, field), encoder) in
+            columns.iter().zip(self.fields.iter()).zip(encoders)
         {
             // We encode a column at a time to minimise dispatch overheads
-            encode_column(&mut rows, column, field.options, dictionary.as_deref())
+            encode_column(&mut rows, column, field.options, &encoder)
         }
 
         if cfg!(debug_assertions) {
@@ -480,17 +596,26 @@ impl RowConverter {
             })
             .collect();
 
+        // SAFETY
+        // We have validated that the rows came from this [`RowConverter`]
+        // and therefore must be valid
+        unsafe { self.convert_raw(&mut rows, validate_utf8) }
+    }
+
+    /// Convert raw bytes into [`ArrayRef`]
+    ///
+    /// # Safety
+    ///
+    /// `rows` must contain valid data for this [`RowConverter`]
+    unsafe fn convert_raw(
+        &self,
+        rows: &mut [&[u8]],
+        validate_utf8: bool,
+    ) -> Result<Vec<ArrayRef>> {
         self.fields
             .iter()
-            .zip(&self.interners)
-            .map(|(field, interner)| {
-                // SAFETY
-                // We have validated that the rows came from this [`RowConverter`]
-                // and therefore must be valid
-                unsafe {
-                    decode_column(field, &mut rows, interner.as_deref(), validate_utf8)
-                }
-            })
+            .zip(&self.codecs)
+            .map(|(field, codec)| decode_column(field, rows, codec, validate_utf8))
             .collect()
     }
 
@@ -505,13 +630,8 @@ impl RowConverter {
     pub fn size(&self) -> usize {
         std::mem::size_of::<Self>()
             + self.fields.iter().map(|x| x.size()).sum::<usize>()
-            + self.interners.capacity()
-                * std::mem::size_of::<Option<Box<OrderPreservingInterner>>>()
-            + self
-                .interners
-                .iter()
-                .filter_map(|x| x.as_ref().map(|x| x.size()))
-                .sum::<usize>()
+            + self.codecs.capacity() * std::mem::size_of::<Codec>()
+            + self.codecs.iter().map(Codec::size).sum::<usize>()
     }
 }
 
@@ -668,7 +788,7 @@ impl<'a> Row<'a> {
     /// Create owned version of the row to detach it from the shared [`Rows`].
     pub fn owned(&self) -> OwnedRow {
         OwnedRow {
-            data: self.data.to_vec(),
+            data: self.data.into(),
             config: self.config.clone(),
         }
     }
@@ -718,7 +838,7 @@ impl<'a> AsRef<[u8]> for Row<'a> {
 /// This contains the data for the one specific row (not the entire buffer of all rows).
 #[derive(Debug, Clone)]
 pub struct OwnedRow {
-    data: Vec<u8>,
+    data: Box<[u8]>,
     config: RowConfig,
 }
 
@@ -783,54 +903,64 @@ fn null_sentinel(options: SortOptions) -> u8 {
 }
 
 /// Computes the length of each encoded [`Rows`] and returns an empty [`Rows`]
-fn new_empty_rows(
-    cols: &[ArrayRef],
-    dictionaries: &[Option<Vec<Option<&[u8]>>>],
-    config: RowConfig,
-) -> Rows {
+fn new_empty_rows(cols: &[ArrayRef], encoders: &[Encoder], config: RowConfig) -> Rows {
     use fixed::FixedLengthEncoding;
 
     let num_rows = cols.first().map(|x| x.len()).unwrap_or(0);
     let mut lengths = vec![0; num_rows];
 
-    for (array, dict) in cols.iter().zip(dictionaries) {
-        downcast_primitive_array! {
-            array => lengths.iter_mut().for_each(|x| *x += fixed::encoded_len(array)),
-            DataType::Null => {},
-            DataType::Boolean => lengths.iter_mut().for_each(|x| *x += bool::ENCODED_LEN),
-            DataType::Binary => as_generic_binary_array::<i32>(array)
-                .iter()
-                .zip(lengths.iter_mut())
-                .for_each(|(slice, length)| *length += variable::encoded_len(slice)),
-            DataType::LargeBinary => as_generic_binary_array::<i64>(array)
-                .iter()
-                .zip(lengths.iter_mut())
-                .for_each(|(slice, length)| *length += variable::encoded_len(slice)),
-            DataType::Utf8 => as_string_array(array)
-                .iter()
-                .zip(lengths.iter_mut())
-                .for_each(|(slice, length)| {
-                    *length += variable::encoded_len(slice.map(|x| x.as_bytes()))
-                }),
-            DataType::LargeUtf8 => as_largestring_array(array)
-                .iter()
-                .zip(lengths.iter_mut())
-                .for_each(|(slice, length)| {
-                    *length += variable::encoded_len(slice.map(|x| x.as_bytes()))
-                }),
-            DataType::Dictionary(_, _) => downcast_dictionary_array! {
-                array => {
-                    let dict = dict.as_ref().unwrap();
-                    for (v, length) in array.keys().iter().zip(lengths.iter_mut()) {
-                        match v.and_then(|v| dict[v as usize]) {
-                            Some(k) => *length += k.len() + 1,
-                            None => *length += 1,
+    for (array, encoder) in cols.iter().zip(encoders) {
+        match encoder {
+            Encoder::Stateless => {
+                downcast_primitive_array! {
+                    array => lengths.iter_mut().for_each(|x| *x += fixed::encoded_len(array)),
+                    DataType::Null => {},
+                    DataType::Boolean => lengths.iter_mut().for_each(|x| *x += bool::ENCODED_LEN),
+                    DataType::Binary => as_generic_binary_array::<i32>(array)
+                        .iter()
+                        .zip(lengths.iter_mut())
+                        .for_each(|(slice, length)| *length += variable::encoded_len(slice)),
+                    DataType::LargeBinary => as_generic_binary_array::<i64>(array)
+                        .iter()
+                        .zip(lengths.iter_mut())
+                        .for_each(|(slice, length)| *length += variable::encoded_len(slice)),
+                    DataType::Utf8 => as_string_array(array)
+                        .iter()
+                        .zip(lengths.iter_mut())
+                        .for_each(|(slice, length)| {
+                            *length += variable::encoded_len(slice.map(|x| x.as_bytes()))
+                        }),
+                    DataType::LargeUtf8 => as_largestring_array(array)
+                        .iter()
+                        .zip(lengths.iter_mut())
+                        .for_each(|(slice, length)| {
+                            *length += variable::encoded_len(slice.map(|x| x.as_bytes()))
+                        }),
+                    _ => unreachable!(),
+                }
+            }
+            Encoder::Dictionary(dict) => {
+                downcast_dictionary_array! {
+                    array => {
+                        for (v, length) in array.keys().iter().zip(lengths.iter_mut()) {
+                            match v.and_then(|v| dict[v as usize]) {
+                                Some(k) => *length += k.len() + 1,
+                                None => *length += 1,
+                            }
                         }
                     }
+                    _ => unreachable!(),
                 }
-                _ => unreachable!(),
             }
-            _ => unreachable!(),
+            Encoder::Struct(rows, null) => {
+                let array = as_struct_array(array);
+                lengths.iter_mut().enumerate().for_each(|(idx, length)| {
+                    match array.is_valid(idx) {
+                        true => *length += 1 + rows.row(idx).as_ref().len(),
+                        false => *length += 1 + null.data.len(),
+                    }
+                });
+            }
         }
     }
 
@@ -872,35 +1002,59 @@ fn encode_column(
     out: &mut Rows,
     column: &ArrayRef,
     opts: SortOptions,
-    dictionary: Option<&[Option<&[u8]>]>,
+    encoder: &Encoder<'_>,
 ) {
-    downcast_primitive_array! {
-        column => fixed::encode(out, column, opts),
-        DataType::Null => {}
-        DataType::Boolean => fixed::encode(out, as_boolean_array(column), opts),
-        DataType::Binary => {
-            variable::encode(out, as_generic_binary_array::<i32>(column).iter(), opts)
+    match encoder {
+        Encoder::Stateless => {
+            downcast_primitive_array! {
+                column => fixed::encode(out, column, opts),
+                DataType::Null => {}
+                DataType::Boolean => fixed::encode(out, as_boolean_array(column), opts),
+                DataType::Binary => {
+                    variable::encode(out, as_generic_binary_array::<i32>(column).iter(), opts)
+                }
+                DataType::LargeBinary => {
+                    variable::encode(out, as_generic_binary_array::<i64>(column).iter(), opts)
+                }
+                DataType::Utf8 => variable::encode(
+                    out,
+                    as_string_array(column).iter().map(|x| x.map(|x| x.as_bytes())),
+                    opts,
+                ),
+                DataType::LargeUtf8 => variable::encode(
+                    out,
+                    as_largestring_array(column)
+                        .iter()
+                        .map(|x| x.map(|x| x.as_bytes())),
+                    opts,
+                ),
+                _ => unreachable!(),
+            }
         }
-        DataType::LargeBinary => {
-            variable::encode(out, as_generic_binary_array::<i64>(column).iter(), opts)
+        Encoder::Dictionary(dict) => {
+            downcast_dictionary_array! {
+                column => encode_dictionary(out, column, dict, opts),
+                _ => unreachable!()
+            }
         }
-        DataType::Utf8 => variable::encode(
-            out,
-            as_string_array(column).iter().map(|x| x.map(|x| x.as_bytes())),
-            opts,
-        ),
-        DataType::LargeUtf8 => variable::encode(
-            out,
-            as_largestring_array(column)
-                .iter()
-                .map(|x| x.map(|x| x.as_bytes())),
-            opts,
-        ),
-        DataType::Dictionary(_, _) => downcast_dictionary_array! {
-            column => encode_dictionary(out, column, dictionary.unwrap(), opts),
-            _ => unreachable!()
+        Encoder::Struct(rows, null) => {
+            let array = as_struct_array(column.as_ref());
+            let null_sentinel = null_sentinel(opts);
+            out.offsets
+                .iter_mut()
+                .skip(1)
+                .enumerate()
+                .for_each(|(idx, offset)| {
+                    let (row, sentinel) = match array.is_valid(idx) {
+                        true => (rows.row(idx), 0x01),
+                        false => (*null, null_sentinel),
+                    };
+                    let end_offset = *offset + 1 + row.as_ref().len();
+                    out.buffer[*offset] = sentinel;
+                    out.buffer[*offset + 1..end_offset].copy_from_slice(row.as_ref());
+                    *offset = end_offset;
+                })
         }
-        _ => unreachable!(),
     }
 }
 
@@ -912,12 +1066,7 @@ macro_rules! decode_primitive_helper {
 
 macro_rules! decode_dictionary_helper {
     ($t:ty, $interner:ident, $v:ident, $options:ident, $rows:ident) => {
-        Arc::new(decode_dictionary::<$t>(
-            $interner.unwrap(),
-            $v.as_ref(),
-            $options,
-            $rows,
-        )?)
+        Arc::new(decode_dictionary::<$t>($interner, $v, $options, $rows)?)
     };
 }
 
@@ -929,28 +1078,73 @@ macro_rules! decode_dictionary_helper {
 unsafe fn decode_column(
     field: &SortField,
     rows: &mut [&[u8]],
-    interner: Option<&OrderPreservingInterner>,
+    codec: &Codec,
     validate_utf8: bool,
 ) -> Result<ArrayRef> {
     let options = field.options;
-    let data_type = field.data_type.clone();
-    let array: ArrayRef = downcast_primitive! {
-        data_type => (decode_primitive_helper, rows, data_type, options),
-        DataType::Null => Arc::new(NullArray::new(rows.len())),
-        DataType::Boolean => Arc::new(decode_bool(rows, options)),
-        DataType::Binary => Arc::new(decode_binary::<i32>(rows, options)),
-        DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
-        DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options, validate_utf8)),
-        DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options, validate_utf8)),
-        DataType::Dictionary(k, v) => downcast_integer! {
-            k.as_ref() => (decode_dictionary_helper, interner, v, options, rows),
-            _ => unreachable!()
-        },
-        _ => {
-            return Err(ArrowError::NotYetImplemented(format!(
-                "converting {} row is not supported",
-                field.data_type
-            )))
+
+    let array: ArrayRef = match codec {
+        Codec::Stateless => {
+            let data_type = field.data_type.clone();
+            downcast_primitive! {
+                data_type => (decode_primitive_helper, rows, data_type, options),
+                DataType::Null => Arc::new(NullArray::new(rows.len())),
+                DataType::Boolean => Arc::new(decode_bool(rows, options)),
+                DataType::Binary => Arc::new(decode_binary::<i32>(rows, options)),
+                DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
+                DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options, validate_utf8)),
+                DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options, validate_utf8)),
+                _ => unreachable!()
+            }
+        }
+        Codec::Dictionary(interner) => {
+            let (k, v) = match &field.data_type {
+                DataType::Dictionary(k, v) => (k.as_ref(), v.as_ref()),
+                _ => unreachable!(),
+            };
+            downcast_integer! {
+                k => (decode_dictionary_helper, interner, v, options, rows),
+                _ => unreachable!()
+            }
+        }
+        Codec::Struct(converter, _) => {
+            let child_fields = match &field.data_type {
+                DataType::Struct(f) => f,
+                _ => unreachable!(),
+            };
+
+            let (null_count, nulls) = fixed::decode_nulls(rows);
+            rows.iter_mut().for_each(|row| *row = &row[1..]);
+            let children = converter.convert_raw(rows, validate_utf8)?;
+
+            let child_data = child_fields
+                .iter()
+                .zip(&children)
+                .map(|(f, c)| {
+                    let data = c.data().clone();
+                    match f.is_nullable() {
+                        true => data,
+                        false => {
+                            assert_eq!(data.null_count(), null_count);
+                            // Need to strip out null buffer if any as this is created
+                            // as an artifact of the row encoding process that encodes
+                            // nulls from the parent struct array in the children
+                            data.into_builder()
+                                .null_count(0)
+                                .null_bit_buffer(None)
+                                .build_unchecked()
+                        }
+                    }
+                })
+                .collect();
+
+            let builder = ArrayDataBuilder::new(field.data_type.clone())
+                .len(rows.len())
+                .null_count(null_count)
+                .null_bit_buffer(Some(nulls))
+                .child_data(child_data);
+
+            Arc::new(StructArray::from(builder.build_unchecked()))
         }
     };
     Ok(array)
@@ -965,6 +1159,7 @@ mod tests {
     use rand::{thread_rng, Rng};
 
     use arrow_array::NullArray;
+    use arrow_buffer::Buffer;
 
     use crate::array::{
         BinaryArray, BooleanArray, DictionaryArray, Float32Array, GenericStringArray,
@@ -1329,6 +1524,54 @@ mod tests {
         assert_eq!(&cols[0], &a);
     }
 
+    #[test]
+    fn test_struct() {
+        // Test basic
+        let a = Arc::new(Int32Array::from(vec![1, 1, 2, 2])) as ArrayRef;
+        let a_f = Field::new("int", DataType::Int32, false);
+        let u = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])) as ArrayRef;
+        let u_f = Field::new("s", DataType::Utf8, false);
+        let s1 = Arc::new(StructArray::from(vec![(a_f, a), (u_f, u)])) as ArrayRef;
+
+        let sort_fields = vec![SortField::new(s1.data_type().clone())];
+        let mut converter = RowConverter::new(sort_fields).unwrap();
+        let r1 = converter.convert_columns(&[Arc::clone(&s1)]).unwrap();
+
+        for (a, b) in r1.iter().zip(r1.iter().skip(1)) {
+            assert!(a < b);
+        }
+
+        let back = converter.convert_rows(&r1).unwrap();
+        assert_eq!(back.len(), 1);
+        assert_eq!(&back[0], &s1);
+
+        // Test struct nullability
+        let data = s1
+            .data()
+            .clone()
+            .into_builder()
+            .null_bit_buffer(Some(Buffer::from_slice_ref([0b00001010])))
+            .null_count(2)
+            .build()
+            .unwrap();
+
+        let s2 = Arc::new(StructArray::from(data)) as ArrayRef;
+        let r2 = converter.convert_columns(&[Arc::clone(&s2)]).unwrap();
+        assert_eq!(r2.row(0), r2.row(2)); // Nulls equal
+        assert!(r2.row(0) < r2.row(1)); // Nulls first
+        assert_ne!(r1.row(0), r2.row(0)); // Value does not equal null
+        assert_eq!(r1.row(1), r2.row(1)); // Values equal
+
+        let back = converter.convert_rows(&r2).unwrap();
+        assert_eq!(back.len(), 1);
+        assert_eq!(&back[0], &s2);
+        let back_s = as_struct_array(&back[0]);
+        for c in back_s.columns() {
+            // Children should not contain nulls
+            assert_eq!(c.null_count(), 0);
+        }
+    }
+
     #[test]
     fn test_primitive_dictionary() {
         let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();