You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/10 21:24:32 UTC

[arrow-rs] branch master updated: early type checks in `RowConverter` (#3080)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d364fe43 early type checks in `RowConverter` (#3080)
8d364fe43 is described below

commit 8d364fe430c39d99ed68665c8c4223e02f54ab56
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Thu Nov 10 21:24:27 2022 +0000

    early type checks in `RowConverter` (#3080)
    
    * refactor: remove duplicate code
    
    Decimal types are already handled by `downcast_primitive`.
    
    * refactor: check supported types when creating `RowConverter`
    
    Check supported row format types when creating the converter instead of
    during conversion. Also add a new method,
    `RowConverter::supports_fields`, to check types without relying on an error.
    
    Closes #3077.
    
    * Simplify
    
    Co-authored-by: Raphael Taylor-Davies <r....@googlemail.com>
---
 arrow/benches/lexsort.rs    |  2 +-
 arrow/benches/row_format.rs |  4 +--
 arrow/src/row/dictionary.rs | 17 +++-------
 arrow/src/row/mod.rs        | 83 +++++++++++++++++++++++++++++----------------
 4 files changed, 61 insertions(+), 45 deletions(-)

diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs
index aebb588cf..5c161ec8d 100644
--- a/arrow/benches/lexsort.rs
+++ b/arrow/benches/lexsort.rs
@@ -105,7 +105,7 @@ fn do_bench(c: &mut Criterion, columns: &[Column], len: usize) {
                     .iter()
                     .map(|a| SortField::new(a.data_type().clone()))
                     .collect();
-                let mut converter = RowConverter::new(fields);
+                let mut converter = RowConverter::new(fields).unwrap();
                 let rows = converter.convert_columns(&arrays).unwrap();
                 let mut sort: Vec<_> = rows.iter().enumerate().collect();
                 sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
diff --git a/arrow/benches/row_format.rs b/arrow/benches/row_format.rs
index 48bb01311..ac9f3106f 100644
--- a/arrow/benches/row_format.rs
+++ b/arrow/benches/row_format.rs
@@ -38,12 +38,12 @@ fn do_bench(c: &mut Criterion, name: &str, cols: Vec<ArrayRef>) {
 
     c.bench_function(&format!("convert_columns {}", name), |b| {
         b.iter(|| {
-            let mut converter = RowConverter::new(fields.clone());
+            let mut converter = RowConverter::new(fields.clone()).unwrap();
             black_box(converter.convert_columns(&cols).unwrap())
         });
     });
 
-    let mut converter = RowConverter::new(fields);
+    let mut converter = RowConverter::new(fields).unwrap();
     let rows = converter.convert_columns(&cols).unwrap();
     // using a pre-prepared row converter should be faster than the first time
     c.bench_function(&format!("convert_columns_prepared {}", name), |b| {
diff --git a/arrow/src/row/dictionary.rs b/arrow/src/row/dictionary.rs
index 950a7d897..d8426ad0c 100644
--- a/arrow/src/row/dictionary.rs
+++ b/arrow/src/row/dictionary.rs
@@ -33,8 +33,8 @@ use std::collections::HashMap;
 pub fn compute_dictionary_mapping(
     interner: &mut OrderPreservingInterner,
     values: &ArrayRef,
-) -> Result<Vec<Option<Interned>>, ArrowError> {
-    Ok(downcast_primitive_array! {
+) -> Vec<Option<Interned>> {
+    downcast_primitive_array! {
         values => interner
             .intern(values.iter().map(|x| x.map(|x| x.encode()))),
         DataType::Binary => {
@@ -53,8 +53,8 @@ pub fn compute_dictionary_mapping(
             let iter = as_largestring_array(values).iter().map(|x| x.map(|x| x.as_bytes()));
             interner.intern(iter)
         }
-        t => return Err(ArrowError::NotYetImplemented(format!("dictionary value {} is not supported", t))),
-    })
+        _ => unreachable!(),
+    }
 }
 
 /// Dictionary types are encoded as
@@ -173,18 +173,11 @@ pub unsafe fn decode_dictionary<K: ArrowDictionaryKeyType>(
         value_type => (decode_primitive_helper, values, value_type),
         DataType::Null => NullArray::new(values.len()).into_data(),
         DataType::Boolean => decode_bool(&values),
-        DataType::Decimal128(_, _) => decode_primitive_helper!(Decimal128Type, values, value_type),
-        DataType::Decimal256(_, _) => decode_primitive_helper!(Decimal256Type, values, value_type),
         DataType::Utf8 => decode_string::<i32>(&values),
         DataType::LargeUtf8 => decode_string::<i64>(&values),
         DataType::Binary => decode_binary::<i32>(&values),
         DataType::LargeBinary => decode_binary::<i64>(&values),
-        _ => {
-            return Err(ArrowError::NotYetImplemented(format!(
-                "decoding dictionary values of {}",
-                value_type
-            )))
-        }
+        _ => unreachable!(),
     };
 
     let data_type =
diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs
index e0312be1f..4fbaa3931 100644
--- a/arrow/src/row/mod.rs
+++ b/arrow/src/row/mod.rs
@@ -40,7 +40,7 @@
 //! let mut converter = RowConverter::new(vec![
 //!     SortField::new(DataType::Int32),
 //!     SortField::new(DataType::Utf8),
-//! ]);
+//! ]).unwrap();
 //! let rows = converter.convert_columns(&arrays).unwrap();
 //!
 //! // Compare rows
@@ -83,7 +83,7 @@
 //!         .iter()
 //!         .map(|a| SortField::new(a.data_type().clone()))
 //!         .collect();
-//!     let mut converter = RowConverter::new(fields);
+//!     let mut converter = RowConverter::new(fields).unwrap();
 //!     let rows = converter.convert_columns(&arrays).unwrap();
 //!     let mut sort: Vec<_> = rows.iter().enumerate().collect();
 //!     sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
@@ -231,12 +231,24 @@ impl SortField {
 
 impl RowConverter {
     /// Create a new [`RowConverter`] with the provided schema
-    pub fn new(fields: Vec<SortField>) -> Self {
+    pub fn new(fields: Vec<SortField>) -> Result<Self> {
+        if !Self::supports_fields(&fields) {
+            return Err(ArrowError::NotYetImplemented(format!(
+                "not yet implemented: {:?}",
+                fields
+            )));
+        }
+
         let interners = (0..fields.len()).map(|_| None).collect();
-        Self {
+        Ok(Self {
             fields: fields.into(),
             interners,
-        }
+        })
+    }
+
+    /// Check if the given fields are supported by the row format.
+    pub fn supports_fields(fields: &[SortField]) -> bool {
+        fields.iter().all(|x| !DataType::is_nested(&x.data_type))
     }
 
     /// Convert [`ArrayRef`] columns into [`Rows`]
@@ -275,7 +287,7 @@ impl RowConverter {
 
                 let interner = interner.get_or_insert_with(Default::default);
 
-                let mapping: Vec<_> = compute_dictionary_mapping(interner, values)?
+                let mapping: Vec<_> = compute_dictionary_mapping(interner, values)
                     .into_iter()
                     .map(|maybe_interned| {
                         maybe_interned.map(|interned| interner.normalized_key(interned))
@@ -286,7 +298,7 @@ impl RowConverter {
             })
             .collect::<Result<Vec<_>>>()?;
 
-        let mut rows = new_empty_rows(columns, &dictionaries, Arc::clone(&self.fields))?;
+        let mut rows = new_empty_rows(columns, &dictionaries, Arc::clone(&self.fields));
 
         for ((column, field), dictionary) in
             columns.iter().zip(self.fields.iter()).zip(dictionaries)
@@ -492,7 +504,7 @@ fn new_empty_rows(
     cols: &[ArrayRef],
     dictionaries: &[Option<Vec<Option<&[u8]>>>],
     fields: Arc<[SortField]>,
-) -> Result<Rows> {
+) -> Rows {
     use fixed::FixedLengthEncoding;
 
     let num_rows = cols.first().map(|x| x.len()).unwrap_or(0);
@@ -535,7 +547,7 @@ fn new_empty_rows(
                 }
                 _ => unreachable!(),
             }
-            t => return Err(ArrowError::NotYetImplemented(format!("not yet implemented: {}", t)))
+            _ => unreachable!(),
         }
     }
 
@@ -565,11 +577,11 @@ fn new_empty_rows(
 
     let buffer = vec![0_u8; cur_offset];
 
-    Ok(Rows {
+    Rows {
         buffer: buffer.into(),
         offsets: offsets.into(),
         fields,
-    })
+    }
 }
 
 /// Encodes a column to the provided [`Rows`] incrementing the offsets as it progresses
@@ -605,7 +617,7 @@ fn encode_column(
             column => encode_dictionary(out, column, dictionary.unwrap(), opts),
             _ => unreachable!()
         }
-        t => unimplemented!("not yet implemented: {}", t)
+        _ => unreachable!(),
     }
 }
 
@@ -747,7 +759,8 @@ mod tests {
         let mut converter = RowConverter::new(vec![
             SortField::new(DataType::Int16),
             SortField::new(DataType::Float32),
-        ]);
+        ])
+        .unwrap();
         let rows = converter.convert_columns(&cols).unwrap();
 
         assert_eq!(rows.offsets.as_ref(), &[0, 8, 16, 24, 32, 40, 48, 56]);
@@ -787,7 +800,8 @@ mod tests {
     fn test_decimal128() {
         let mut converter = RowConverter::new(vec![SortField::new(
             DataType::Decimal128(DECIMAL128_MAX_PRECISION, 7),
-        )]);
+        )])
+        .unwrap();
         let col = Arc::new(
             Decimal128Array::from_iter([
                 None,
@@ -815,7 +829,8 @@ mod tests {
     fn test_decimal256() {
         let mut converter = RowConverter::new(vec![SortField::new(
             DataType::Decimal256(DECIMAL256_MAX_PRECISION, 7),
-        )]);
+        )])
+        .unwrap();
         let col = Arc::new(
             Decimal256Array::from_iter([
                 None,
@@ -843,7 +858,8 @@ mod tests {
 
     #[test]
     fn test_bool() {
-        let mut converter = RowConverter::new(vec![SortField::new(DataType::Boolean)]);
+        let mut converter =
+            RowConverter::new(vec![SortField::new(DataType::Boolean)]).unwrap();
 
         let col = Arc::new(BooleanArray::from_iter([None, Some(false), Some(true)]))
             as ArrayRef;
@@ -862,7 +878,8 @@ mod tests {
                 descending: true,
                 nulls_first: false,
             },
-        )]);
+        )])
+        .unwrap();
 
         let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
         assert!(rows.row(2) < rows.row(1));
@@ -879,7 +896,7 @@ mod tests {
         let d = a.data_type().clone();
 
         let mut converter =
-            RowConverter::new(vec![SortField::new(a.data_type().clone())]);
+            RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
         let rows = converter.convert_columns(&[Arc::new(a) as _]).unwrap();
         let back = converter.convert_rows(&rows).unwrap();
         assert_eq!(back.len(), 1);
@@ -905,7 +922,7 @@ mod tests {
         );
 
         assert_eq!(dict_with_tz.data_type(), &d);
-        let mut converter = RowConverter::new(vec![SortField::new(d.clone())]);
+        let mut converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
         let rows = converter
             .convert_columns(&[Arc::new(dict_with_tz) as _])
             .unwrap();
@@ -917,7 +934,8 @@ mod tests {
     #[test]
     fn test_null_encoding() {
         let col = Arc::new(NullArray::new(10));
-        let mut converter = RowConverter::new(vec![SortField::new(DataType::Null)]);
+        let mut converter =
+            RowConverter::new(vec![SortField::new(DataType::Null)]).unwrap();
         let rows = converter.convert_columns(&[col]).unwrap();
         assert_eq!(rows.num_rows(), 10);
         assert_eq!(rows.row(1).data.len(), 0);
@@ -933,7 +951,8 @@ mod tests {
             Some(""),
         ])) as ArrayRef;
 
-        let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
+        let mut converter =
+            RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
         let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
 
         assert!(rows.row(1) < rows.row(0));
@@ -958,7 +977,8 @@ mod tests {
             Some(vec![0xFF_u8; variable::BLOCK_SIZE + 1]),
         ])) as ArrayRef;
 
-        let mut converter = RowConverter::new(vec![SortField::new(DataType::Binary)]);
+        let mut converter =
+            RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();
         let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
 
         for i in 0..rows.num_rows() {
@@ -983,7 +1003,8 @@ mod tests {
                 descending: true,
                 nulls_first: false,
             },
-        )]);
+        )])
+        .unwrap();
         let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
 
         for i in 0..rows.num_rows() {
@@ -1017,7 +1038,7 @@ mod tests {
         ])) as ArrayRef;
 
         let mut converter =
-            RowConverter::new(vec![SortField::new(a.data_type().clone())]);
+            RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
         let rows_a = converter.convert_columns(&[Arc::clone(&a)]).unwrap();
 
         assert!(rows_a.row(3) < rows_a.row(5));
@@ -1052,7 +1073,8 @@ mod tests {
                 descending: true,
                 nulls_first: false,
             },
-        )]);
+        )])
+        .unwrap();
 
         let rows_c = converter.convert_columns(&[Arc::clone(&a)]).unwrap();
         assert!(rows_c.row(3) > rows_c.row(5));
@@ -1078,7 +1100,7 @@ mod tests {
         let a = builder.finish();
 
         let mut converter =
-            RowConverter::new(vec![SortField::new(a.data_type().clone())]);
+            RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
         let rows = converter.convert_columns(&[Arc::new(a)]).unwrap();
         assert!(rows.row(0) < rows.row(1));
         assert!(rows.row(2) < rows.row(0));
@@ -1104,7 +1126,7 @@ mod tests {
             .build()
             .unwrap();
 
-        let mut converter = RowConverter::new(vec![SortField::new(data_type)]);
+        let mut converter = RowConverter::new(vec![SortField::new(data_type)]).unwrap();
         let rows = converter
             .convert_columns(&[Arc::new(DictionaryArray::<Int32Type>::from(data))])
             .unwrap();
@@ -1119,10 +1141,11 @@ mod tests {
     #[should_panic(expected = "rows were not produced by this RowConverter")]
     fn test_different_converter() {
         let values = Arc::new(Int32Array::from_iter([Some(1), Some(-1)]));
-        let mut converter = RowConverter::new(vec![SortField::new(DataType::Int32)]);
+        let mut converter =
+            RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap();
         let rows = converter.convert_columns(&[values]).unwrap();
 
-        let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]);
+        let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap();
         let _ = converter.convert_rows(&rows);
     }
 
@@ -1266,7 +1289,7 @@ mod tests {
                 .map(|(o, a)| SortField::new_with_options(a.data_type().clone(), o))
                 .collect();
 
-            let mut converter = RowConverter::new(columns);
+            let mut converter = RowConverter::new(columns).unwrap();
             let rows = converter.convert_columns(&arrays).unwrap();
 
             for i in 0..len {