You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/09 17:38:21 UTC

[arrow-rs] branch master updated: Fix row format decode loses timezone (#3063) (#3064)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a3ecc2ea Fix row format decode loses timezone (#3063) (#3064)
5a3ecc2ea is described below

commit 5a3ecc2ea270af7f9aba4c1a162072acf9541fb8
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Nov 10 06:38:16 2022 +1300

    Fix row format decode loses timezone (#3063) (#3064)
---
 arrow/src/row/dictionary.rs | 36 +++++++++++++++-----------
 arrow/src/row/fixed.rs      | 17 ++++++++++---
 arrow/src/row/mod.rs        | 61 +++++++++++++++++++++++++++++++++++----------
 3 files changed, 83 insertions(+), 31 deletions(-)

diff --git a/arrow/src/row/dictionary.rs b/arrow/src/row/dictionary.rs
index 1ec7c2a21..950a7d897 100644
--- a/arrow/src/row/dictionary.rs
+++ b/arrow/src/row/dictionary.rs
@@ -90,8 +90,8 @@ pub fn encode_dictionary<K: ArrowDictionaryKeyType>(
 }
 
 macro_rules! decode_primitive_helper {
-    ($t:ty, $values: ident) => {
-        decode_primitive::<$t>(&$values)
+    ($t:ty, $values: ident, $data_type:ident) => {
+        decode_primitive::<$t>(&$values, $data_type.clone())
     };
 }
 
@@ -170,11 +170,11 @@ pub unsafe fn decode_dictionary<K: ArrowDictionaryKeyType>(
     }
 
     let child = downcast_primitive! {
-        &value_type => (decode_primitive_helper, values),
+        value_type => (decode_primitive_helper, values, value_type),
         DataType::Null => NullArray::new(values.len()).into_data(),
         DataType::Boolean => decode_bool(&values),
-        DataType::Decimal128(p, s) => decode_decimal::<Decimal128Type>(&values, *p, *s),
-        DataType::Decimal256(p, s) => decode_decimal::<Decimal256Type>(&values, *p, *s),
+        DataType::Decimal128(_, _) => decode_primitive_helper!(Decimal128Type, values, value_type),
+        DataType::Decimal256(_, _) => decode_primitive_helper!(Decimal256Type, values, value_type),
         DataType::Utf8 => decode_string::<i32>(&values),
         DataType::LargeUtf8 => decode_string::<i64>(&values),
         DataType::Binary => decode_binary::<i32>(&values),
@@ -247,7 +247,11 @@ fn decode_bool(values: &[&[u8]]) -> ArrayData {
 }
 
 /// Decodes a fixed length type array from dictionary values
-fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
+///
+/// # Safety
+///
+/// `data_type` must be appropriate native type for `T`
+unsafe fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
     values: &[&[u8]],
     data_type: DataType,
 ) -> ArrayData {
@@ -267,17 +271,19 @@ fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
 }
 
 /// Decodes a `PrimitiveArray` from dictionary values
-fn decode_primitive<T: ArrowPrimitiveType>(values: &[&[u8]]) -> ArrayData
+fn decode_primitive<T: ArrowPrimitiveType>(
+    values: &[&[u8]],
+    data_type: DataType,
+) -> ArrayData
 where
     T::Native: FixedLengthEncoding,
 {
-    decode_fixed::<T::Native>(values, T::DATA_TYPE)
-}
+    assert_eq!(
+        std::mem::discriminant(&T::DATA_TYPE),
+        std::mem::discriminant(&data_type),
+    );
 
-/// Decodes a `DecimalArray` from dictionary values
-fn decode_decimal<T: DecimalType>(values: &[&[u8]], precision: u8, scale: u8) -> ArrayData
-where
-    T::Native: FixedLengthEncoding,
-{
-    decode_fixed::<T::Native>(values, T::TYPE_CONSTRUCTOR(precision, scale))
+    // SAFETY:
+    // Validated data type above
+    unsafe { decode_fixed::<T::Native>(values, data_type) }
 }
diff --git a/arrow/src/row/fixed.rs b/arrow/src/row/fixed.rs
index d5935cfb6..76bf358e7 100644
--- a/arrow/src/row/fixed.rs
+++ b/arrow/src/row/fixed.rs
@@ -267,7 +267,11 @@ pub fn decode_bool(rows: &mut [&[u8]], options: SortOptions) -> BooleanArray {
 }
 
 /// Decodes a `ArrayData` from rows based on the provided `FixedLengthEncoding` `T`
-fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
+///
+/// # Safety
+///
+/// `data_type` must be appropriate native type for `T`
+unsafe fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
     rows: &mut [&[u8]],
     data_type: DataType,
     options: SortOptions,
@@ -319,16 +323,23 @@ fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
         .null_bit_buffer(Some(nulls.into()));
 
     // SAFETY: Buffers correct length
-    unsafe { builder.build_unchecked() }
+    builder.build_unchecked()
 }
 
 /// Decodes a `PrimitiveArray` from rows
 pub fn decode_primitive<T: ArrowPrimitiveType>(
     rows: &mut [&[u8]],
+    data_type: DataType,
     options: SortOptions,
 ) -> PrimitiveArray<T>
 where
     T::Native: FixedLengthEncoding + ToByteSlice,
 {
-    decode_fixed::<T::Native>(rows, T::DATA_TYPE, options).into()
+    assert_eq!(
+        std::mem::discriminant(&T::DATA_TYPE),
+        std::mem::discriminant(&data_type),
+    );
+    // SAFETY:
+    // Validated data type above
+    unsafe { decode_fixed::<T::Native>(rows, data_type, options).into() }
 }
diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs
index 8af642240..15fe5dc42 100644
--- a/arrow/src/row/mod.rs
+++ b/arrow/src/row/mod.rs
@@ -629,8 +629,8 @@ fn encode_column(
 }
 
 macro_rules! decode_primitive_helper {
-    ($t:ty, $rows: ident, $options:ident) => {
-        Arc::new(decode_primitive::<$t>($rows, $options))
+    ($t:ty, $rows: ident, $data_type:ident, $options:ident) => {
+        Arc::new(decode_primitive::<$t>($rows, $data_type, $options))
     };
 }
 
@@ -645,24 +645,17 @@ unsafe fn decode_column(
     interner: Option<&OrderPreservingInterner>,
 ) -> Result<ArrayRef> {
     let options = field.options;
+    let data_type = field.data_type.clone();
     let array: ArrayRef = downcast_primitive! {
-        &field.data_type => (decode_primitive_helper, rows, options),
+        data_type => (decode_primitive_helper, rows, data_type, options),
         DataType::Null => Arc::new(NullArray::new(rows.len())),
         DataType::Boolean => Arc::new(decode_bool(rows, options)),
         DataType::Binary => Arc::new(decode_binary::<i32>(rows, options)),
         DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
         DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options)),
         DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options)),
-        DataType::Decimal128(p, s) => Arc::new(
-            decode_primitive::<Decimal128Type>(rows, options)
-                .with_precision_and_scale(*p, *s)
-                .unwrap(),
-        ),
-        DataType::Decimal256(p, s) => Arc::new(
-            decode_primitive::<Decimal256Type>(rows, options)
-                .with_precision_and_scale(*p, *s)
-                .unwrap(),
-        ),
+        DataType::Decimal128(_, _) => decode_primitive_helper!(Decimal128Type, rows, data_type, options),
+        DataType::Decimal256(_, _) => decode_primitive_helper!(Decimal256Type, rows, data_type, options),
         DataType::Dictionary(k, v) => match k.as_ref() {
             DataType::Int8 => Arc::new(decode_dictionary::<Int8Type>(
                 interner.unwrap(),
@@ -900,6 +893,48 @@ mod tests {
         assert_eq!(&cols[0], &col);
     }
 
+    #[test]
+    fn test_timezone() {
+        let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5])
+            .with_timezone("+01:00".to_string());
+        let d = a.data_type().clone();
+
+        let mut converter =
+            RowConverter::new(vec![SortField::new(a.data_type().clone())]);
+        let rows = converter.convert_columns(&[Arc::new(a) as _]).unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        assert_eq!(back[0].data_type(), &d);
+
+        // Test dictionary
+        let mut a =
+            PrimitiveDictionaryBuilder::<Int32Type, TimestampNanosecondType>::new();
+        a.append(34).unwrap();
+        a.append_null();
+        a.append(345).unwrap();
+
+        // Construct dictionary with a timezone
+        let dict = a.finish();
+        let values = TimestampNanosecondArray::from(dict.values().data().clone());
+        let dict_with_tz = dict.with_values(&values.with_timezone("+02:00".to_string()));
+        let d = DataType::Dictionary(
+            Box::new(DataType::Int32),
+            Box::new(DataType::Timestamp(
+                TimeUnit::Nanosecond,
+                Some("+02:00".to_string()),
+            )),
+        );
+
+        assert_eq!(dict_with_tz.data_type(), &d);
+        let mut converter = RowConverter::new(vec![SortField::new(d.clone())]);
+        let rows = converter
+            .convert_columns(&[Arc::new(dict_with_tz) as _])
+            .unwrap();
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        assert_eq!(back[0].data_type(), &d);
+    }
+
     #[test]
     fn test_null_encoding() {
         let col = Arc::new(NullArray::new(10));