Posted to commits@arrow.apache.org by tu...@apache.org on 2023/06/02 13:29:06 UTC

[arrow-rs] branch master updated: Add roundtrip tests for Decimal256 and fix issues (#4264) (#4311)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new a7164849c Add roundtrip tests for Decimal256 and fix issues (#4264) (#4311)
a7164849c is described below

commit a7164849c56be041fc9ade8f9a55efac40e91f99
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Fri Jun 2 14:29:00 2023 +0100

    Add roundtrip tests for Decimal256 and fix issues (#4264) (#4311)
    
    * Add roundtrip tests for Decimal256 and fix issues (#4264)
    
    * Review feedback
---
 parquet/src/arrow/arrow_reader/mod.rs | 58 ++++++++++++++++++++++++++++++++++-
 parquet/src/arrow/schema/mod.rs       | 16 ++--------
 parquet/src/arrow/schema/primitive.rs | 17 ++++++++--
 3 files changed, 73 insertions(+), 18 deletions(-)
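
As a rough, self-contained sketch of the roundtrip this commit locks down (not part of the patch itself), the snippet below writes a Decimal256 column to an in-memory Parquet buffer and reads it back. It assumes the parquet, arrow-array and arrow-buffer crates at roughly this revision; ArrowWriter and ParquetRecordBatchReaderBuilder are the APIs exercised by the new test, while Decimal256Array and i256::from_i128 are standard arrow-rs helpers used here only for illustration.

    use std::sync::Arc;

    use arrow_array::{ArrayRef, Decimal256Array, RecordBatch};
    use arrow_buffer::i256;
    use bytes::Bytes;
    use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
    use parquet::arrow::ArrowWriter;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // A Decimal256 column with precision 40, scale 2 (too wide for Decimal128).
        let decimals = Decimal256Array::from_iter_values(
            [1i128, 2, 3].into_iter().map(i256::from_i128),
        )
        .with_precision_and_scale(40, 2)?;

        let batch =
            RecordBatch::try_from_iter([("d", Arc::new(decimals) as ArrayRef)])?;

        // Write the batch to an in-memory Parquet buffer.
        let mut buffer = Vec::new();
        let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), None)?;
        writer.write(&batch)?;
        writer.close()?;

        // Read it back; the embedded Arrow schema restores Decimal256 on the way out.
        let mut reader =
            ParquetRecordBatchReaderBuilder::try_new(Bytes::from(buffer))?.build()?;
        let out = reader.next().unwrap()?;
        assert_eq!(batch, out);
        Ok(())
    }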

diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs
index deca0c719..432b00399 100644
--- a/parquet/src/arrow/arrow_reader/mod.rs
+++ b/parquet/src/arrow/arrow_reader/mod.rs
@@ -543,13 +543,15 @@ mod tests {
     use std::sync::Arc;
 
     use bytes::Bytes;
+    use num::PrimInt;
     use rand::{thread_rng, Rng, RngCore};
     use tempfile::tempfile;
 
     use arrow_array::builder::*;
+    use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType};
     use arrow_array::*;
     use arrow_array::{RecordBatch, RecordBatchReader};
-    use arrow_buffer::{i256, Buffer};
+    use arrow_buffer::{i256, ArrowNativeType, Buffer};
     use arrow_data::ArrayDataBuilder;
     use arrow_schema::{DataType as ArrowDataType, Field, Fields, Schema};
 
@@ -2554,4 +2556,58 @@ mod tests {
         assert_eq!(out.num_rows(), 1);
         assert_eq!(out, batch.slice(2, 1));
     }
+
+    fn test_decimal_roundtrip<T: DecimalType>() {
+        // Precision <= 9 -> INT32
+        // Precision <= 18 -> INT64
+        // Precision > 18 -> FIXED_LEN_BYTE_ARRAY
+
+        let d = |values: Vec<usize>, p: u8| {
+            let iter = values.into_iter().map(T::Native::usize_as);
+            PrimitiveArray::<T>::from_iter_values(iter)
+                .with_precision_and_scale(p, 2)
+                .unwrap()
+        };
+
+        let d1 = d(vec![1, 2, 3, 4, 5], 9);
+        let d2 = d(vec![1, 2, 3, 4, 10.pow(10) - 1], 10);
+        let d3 = d(vec![1, 2, 3, 4, 10.pow(18) - 1], 18);
+        let d4 = d(vec![1, 2, 3, 4, 10.pow(19) - 1], 19);
+
+        let batch = RecordBatch::try_from_iter([
+            ("d1", Arc::new(d1) as ArrayRef),
+            ("d2", Arc::new(d2) as ArrayRef),
+            ("d3", Arc::new(d3) as ArrayRef),
+            ("d4", Arc::new(d4) as ArrayRef),
+        ])
+        .unwrap();
+
+        let mut buffer = Vec::with_capacity(1024);
+        let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), None).unwrap();
+        writer.write(&batch).unwrap();
+        writer.close().unwrap();
+
+        let builder =
+            ParquetRecordBatchReaderBuilder::try_new(Bytes::from(buffer)).unwrap();
+        let t1 = builder.parquet_schema().columns()[0].physical_type();
+        assert_eq!(t1, PhysicalType::INT32);
+        let t2 = builder.parquet_schema().columns()[1].physical_type();
+        assert_eq!(t2, PhysicalType::INT64);
+        let t3 = builder.parquet_schema().columns()[2].physical_type();
+        assert_eq!(t3, PhysicalType::INT64);
+        let t4 = builder.parquet_schema().columns()[3].physical_type();
+        assert_eq!(t4, PhysicalType::FIXED_LEN_BYTE_ARRAY);
+
+        let mut reader = builder.build().unwrap();
+        assert_eq!(batch.schema(), reader.schema());
+
+        let out = reader.next().unwrap().unwrap();
+        assert_eq!(batch, out);
+    }
+
+    #[test]
+    fn test_decimal() {
+        test_decimal_roundtrip::<Decimal128Type>();
+        test_decimal_roundtrip::<Decimal256Type>();
+    }
 }
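
The physical type expectations asserted in the test above fall out of the Parquet decimal spec: a decimal of precision p has unscaled values up to 10^p - 1, so the smallest integer type that can hold that bound is chosen. A quick standalone check of the cut-offs (illustrative only, not part of the patch):

    fn main() {
        // Largest unscaled value at a given precision is 10^p - 1.
        let p18_max: i128 = 10i128.pow(18) - 1; //   999_999_999_999_999_999
        let p19_max: i128 = 10i128.pow(19) - 1; // 9_999_999_999_999_999_999

        assert!(10i64.pow(9) - 1 <= i32::MAX as i64); // precision  9 fits INT32
        assert!(p18_max <= i64::MAX as i128);         // precision 18 fits INT64
        assert!(p19_max > i64::MAX as i128);          // precision 19 needs FIXED_LEN_BYTE_ARRAY
    }
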
diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs
index 3b9691044..7469d86dc 100644
--- a/parquet/src/arrow/schema/mod.rs
+++ b/parquet/src/arrow/schema/mod.rs
@@ -443,7 +443,8 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
                 .with_length(*length)
                 .build()
         }
-        DataType::Decimal128(precision, scale) => {
+        DataType::Decimal128(precision, scale)
+        | DataType::Decimal256(precision, scale) => {
             // Decimal precision determines the Parquet physical type to use.
             // Following the: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#decimal
             let (physical_type, length) = if *precision > 1 && *precision <= 9 {
@@ -467,19 +468,6 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
                 .with_scale(*scale as i32)
                 .build()
         }
-        DataType::Decimal256(precision, scale) => {
-            // For the decimal256, use the fixed length byte array to store the data
-            Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
-                .with_repetition(repetition)
-                .with_length(decimal_length_from_precision(*precision) as i32)
-                .with_logical_type(Some(LogicalType::Decimal {
-                    scale: *scale as i32,
-                    precision: *precision as i32,
-                }))
-                .with_precision(*precision as i32)
-                .with_scale(*scale as i32)
-                .build()
-        }
         DataType::Utf8 | DataType::LargeUtf8 => {
             Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY)
                 .with_logical_type(Some(LogicalType::String))
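
With the Decimal128 and Decimal256 arms merged above, a low-precision Decimal256 field is now written as INT32 or INT64 rather than always FIXED_LEN_BYTE_ARRAY. A sketch of the writer-side effect (not part of the patch, and assuming parquet::arrow::arrow_to_parquet_schema is available at this revision):

    use arrow_schema::{DataType, Field, Schema};
    use parquet::arrow::arrow_to_parquet_schema;
    use parquet::basic::Type as PhysicalType;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let schema = Schema::new(vec![
            // Precision 9 now selects INT32 even for Decimal256 ...
            Field::new("small", DataType::Decimal256(9, 2), false),
            // ... while precision 40 still requires FIXED_LEN_BYTE_ARRAY.
            Field::new("large", DataType::Decimal256(40, 2), false),
        ]);
        let descr = arrow_to_parquet_schema(&schema)?;
        assert_eq!(descr.column(0).physical_type(), PhysicalType::INT32);
        assert_eq!(descr.column(1).physical_type(), PhysicalType::FIXED_LEN_BYTE_ARRAY);
        Ok(())
    }

The reader-side counterpart, which promotes the decoded type back to Decimal256 when the embedded Arrow schema asks for it, follows in schema/primitive.rs below.
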
diff --git a/parquet/src/arrow/schema/primitive.rs b/parquet/src/arrow/schema/primitive.rs
index 62133f157..c67f78076 100644
--- a/parquet/src/arrow/schema/primitive.rs
+++ b/parquet/src/arrow/schema/primitive.rs
@@ -20,7 +20,7 @@ use crate::basic::{
 };
 use crate::errors::{ParquetError, Result};
 use crate::schema::types::{BasicTypeInfo, Type};
-use arrow_schema::{DataType, IntervalUnit, TimeUnit};
+use arrow_schema::{DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION};
 
 /// Converts [`Type`] to [`DataType`] with an optional `arrow_type_hint`
 /// provided by the arrow schema
@@ -62,6 +62,9 @@ fn apply_hint(parquet: DataType, hint: DataType) -> DataType {
         // Determine interval time unit (#1666)
         (DataType::Interval(_), DataType::Interval(_)) => hint,
 
+        // Promote to Decimal256
+        (DataType::Decimal128(_, _), DataType::Decimal256(_, _)) => hint,
+
         // Potentially preserve dictionary encoding
         (_, DataType::Dictionary(_, value)) => {
             // Apply hint to inner type
@@ -103,6 +106,14 @@ fn from_parquet(parquet_type: &Type) -> Result<DataType> {
     }
 }
 
+fn decimal_type(scale: i32, precision: i32) -> Result<DataType> {
+    if precision <= DECIMAL128_MAX_PRECISION as _ {
+        decimal_128_type(scale, precision)
+    } else {
+        decimal_256_type(scale, precision)
+    }
+}
+
 fn decimal_128_type(scale: i32, precision: i32) -> Result<DataType> {
     let scale = scale
         .try_into()
@@ -255,8 +266,8 @@ fn from_byte_array(info: &BasicTypeInfo, precision: i32, scale: i32) -> Result<D
                 precision: p,
             }),
             _,
-        ) => decimal_128_type(s, p),
-        (None, ConvertedType::DECIMAL) => decimal_128_type(scale, precision),
+        ) => decimal_type(s, p),
+        (None, ConvertedType::DECIMAL) => decimal_type(scale, precision),
         (logical, converted) => Err(arrow_err!(
             "Unable to convert parquet BYTE_ARRAY logical type {:?} or converted type {}",
             logical,