You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/25 07:06:34 UTC

[arrow-rs] branch master updated: Row decode cleanups (#3180)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new d74c48e05 Row decode cleanups (#3180)
d74c48e05 is described below

commit d74c48e0541aba2941daf6ea2ce8dce84619bda5
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Fri Nov 25 07:06:30 2022 +0000

    Row decode cleanups (#3180)
    
    * Row decode cleanups
    
    * Clippy
---
 arrow/src/row/fixed.rs | 65 +++++++++++++++++----------------------------
 arrow/src/row/mod.rs   | 71 +++++++++++---------------------------------------
 2 files changed, 39 insertions(+), 97 deletions(-)

diff --git a/arrow/src/row/fixed.rs b/arrow/src/row/fixed.rs
index 76bf358e7..0bad033d9 100644
--- a/arrow/src/row/fixed.rs
+++ b/arrow/src/row/fixed.rs
@@ -19,8 +19,9 @@ use crate::array::PrimitiveArray;
 use crate::compute::SortOptions;
 use crate::datatypes::ArrowPrimitiveType;
 use crate::row::{null_sentinel, Rows};
+use arrow_array::builder::BufferBuilder;
 use arrow_array::BooleanArray;
-use arrow_buffer::{bit_util, i256, MutableBuffer, ToByteSlice};
+use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer};
 use arrow_data::{ArrayData, ArrayDataBuilder};
 use arrow_schema::DataType;
 use half::f16;
@@ -266,61 +267,43 @@ pub fn decode_bool(rows: &mut [&[u8]], options: SortOptions) -> BooleanArray {
     unsafe { BooleanArray::from(builder.build_unchecked()) }
 }
 
+fn decode_nulls(rows: &[&[u8]]) -> (usize, Buffer) {
+    let mut null_count = 0;
+    let buffer = MutableBuffer::collect_bool(rows.len(), |idx| {
+        let valid = rows[idx][0] == 1;
+        null_count += !valid as usize;
+        valid
+    })
+    .into();
+    (null_count, buffer)
+}
+
 /// Decodes a `ArrayData` from rows based on the provided `FixedLengthEncoding` `T`
 ///
 /// # Safety
 ///
 /// `data_type` must be appropriate native type for `T`
-unsafe fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
+unsafe fn decode_fixed<T: FixedLengthEncoding + ArrowNativeType>(
     rows: &mut [&[u8]],
     data_type: DataType,
     options: SortOptions,
 ) -> ArrayData {
     let len = rows.len();
 
-    let mut null_count = 0;
-    let mut nulls = MutableBuffer::new(bit_util::ceil(len, 64) * 8);
-    let mut values = MutableBuffer::new(std::mem::size_of::<T>() * len);
+    let mut values = BufferBuilder::<T>::new(len);
+    let (null_count, nulls) = decode_nulls(rows);
 
-    let chunks = len / 64;
-    let remainder = len % 64;
-    for chunk in 0..chunks {
-        let mut null_packed = 0;
-
-        for bit_idx in 0..64 {
-            let i = split_off(&mut rows[bit_idx + chunk * 64], T::ENCODED_LEN);
-            let null = i[0] == 1;
-            null_count += !null as usize;
-            null_packed |= (null as u64) << bit_idx;
-
-            let value = T::Encoded::from_slice(&i[1..], options.descending);
-            values.push(T::decode(value));
-        }
-
-        nulls.push(null_packed);
-    }
-
-    if remainder != 0 {
-        let mut null_packed = 0;
-
-        for bit_idx in 0..remainder {
-            let i = split_off(&mut rows[bit_idx + chunks * 64], T::ENCODED_LEN);
-            let null = i[0] == 1;
-            null_count += !null as usize;
-            null_packed |= (null as u64) << bit_idx;
-
-            let value = T::Encoded::from_slice(&i[1..], options.descending);
-            values.push(T::decode(value));
-        }
-
-        nulls.push(null_packed);
+    for row in rows {
+        let i = split_off(row, T::ENCODED_LEN);
+        let value = T::Encoded::from_slice(&i[1..], options.descending);
+        values.append(T::decode(value));
     }
 
     let builder = ArrayDataBuilder::new(data_type)
-        .len(rows.len())
+        .len(len)
         .null_count(null_count)
-        .add_buffer(values.into())
-        .null_bit_buffer(Some(nulls.into()));
+        .add_buffer(values.finish())
+        .null_bit_buffer(Some(nulls));
 
     // SAFETY: Buffers correct length
     builder.build_unchecked()
@@ -333,7 +316,7 @@ pub fn decode_primitive<T: ArrowPrimitiveType>(
     options: SortOptions,
 ) -> PrimitiveArray<T>
 where
-    T::Native: FixedLengthEncoding + ToByteSlice,
+    T::Native: FixedLengthEncoding,
 {
     assert_eq!(
         std::mem::discriminant(&T::DATA_TYPE),
diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs
index 6ce9f2b12..4f48b46cb 100644
--- a/arrow/src/row/mod.rs
+++ b/arrow/src/row/mod.rs
@@ -908,11 +908,22 @@ fn encode_column(
 }
 
 macro_rules! decode_primitive_helper {
-    ($t:ty, $rows: ident, $data_type:ident, $options:ident) => {
+    ($t:ty, $rows:ident, $data_type:ident, $options:ident) => {
         Arc::new(decode_primitive::<$t>($rows, $data_type, $options))
     };
 }
 
+macro_rules! decode_dictionary_helper {
+    ($t:ty, $interner:ident, $v:ident, $options:ident, $rows:ident) => {
+        Arc::new(decode_dictionary::<$t>(
+            $interner.unwrap(),
+            $v.as_ref(),
+            $options,
+            $rows,
+        )?)
+    };
+}
+
 /// Decodes a the provided `field` from `rows`
 ///
 /// # Safety
@@ -934,61 +945,9 @@ unsafe fn decode_column(
         DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
         DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options, validate_utf8)),
         DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options, validate_utf8)),
-        DataType::Dictionary(k, v) => match k.as_ref() {
-            DataType::Int8 => Arc::new(decode_dictionary::<Int8Type>(
-                interner.unwrap(),
-                v.as_ref(),
-                options,
-                rows,
-            )?),
-            DataType::Int16 => Arc::new(decode_dictionary::<Int16Type>(
-                interner.unwrap(),
-                v.as_ref(),
-                options,
-                rows,
-            )?),
-            DataType::Int32 => Arc::new(decode_dictionary::<Int32Type>(
-                interner.unwrap(),
-                v.as_ref(),
-                options,
-                rows,
-            )?),
-            DataType::Int64 => Arc::new(decode_dictionary::<Int64Type>(
-                interner.unwrap(),
-                v.as_ref(),
-                options,
-                rows,
-            )?),
-            DataType::UInt8 => Arc::new(decode_dictionary::<UInt8Type>(
-                interner.unwrap(),
-                v.as_ref(),
-                options,
-                rows,
-            )?),
-            DataType::UInt16 => Arc::new(decode_dictionary::<UInt16Type>(
-                interner.unwrap(),
-                v.as_ref(),
-                options,
-                rows,
-            )?),
-            DataType::UInt32 => Arc::new(decode_dictionary::<UInt32Type>(
-                interner.unwrap(),
-                v.as_ref(),
-                options,
-                rows,
-            )?),
-            DataType::UInt64 => Arc::new(decode_dictionary::<UInt64Type>(
-                interner.unwrap(),
-                v.as_ref(),
-                options,
-                rows,
-            )?),
-            _ => {
-                return Err(ArrowError::InvalidArgumentError(format!(
-                    "{} is not a valid dictionary key type",
-                    field.data_type
-                )));
-            }
+        DataType::Dictionary(k, v) => downcast_integer! {
+            k.as_ref() => (decode_dictionary_helper, interner, v, options, rows),
+            _ => unreachable!()
         },
         _ => {
             return Err(ArrowError::NotYetImplemented(format!(