You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/01/11 02:19:21 UTC

[arrow-rs] branch master updated: Support decimal int32/64 for writer (#3431)

This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new ccb80e82b Support decimal int32/64 for writer (#3431)
ccb80e82b is described below

commit ccb80e82bbfd19e8b353e107ba41cfe0cbaa029a
Author: Kun Liu <li...@apache.org>
AuthorDate: Wed Jan 11 10:19:16 2023 +0800

    Support decimal int32/64 for writer (#3431)
---
 parquet/src/arrow/arrow_writer/mod.rs | 41 +++++++++++++++-----
 parquet/src/arrow/schema/mod.rs       | 73 ++++++++++++++++++++++++++++-------
 2 files changed, 90 insertions(+), 24 deletions(-)

diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 340ab246a..311981593 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -21,7 +21,9 @@ use std::collections::VecDeque;
 use std::io::Write;
 use std::sync::Arc;
 
-use arrow_array::{Array, ArrayRef, RecordBatch};
+use arrow_array::cast::as_primitive_array;
+use arrow_array::types::Decimal128Type;
+use arrow_array::{types, Array, ArrayRef, RecordBatch};
 use arrow_schema::{DataType as ArrowDataType, IntervalUnit, SchemaRef};
 
 use super::schema::{
@@ -397,6 +399,12 @@ fn write_leaf(
                     let array: &[i32] = data.buffers()[0].typed_data();
                     write_primitive(typed, &array[offset..offset + data.len()], levels)?
                 }
+                ArrowDataType::Decimal128(_, _) => {
+                    // Low-precision decimals are written as int32 values
+                    let array = as_primitive_array::<Decimal128Type>(column)
+                        .unary::<_, types::Int32Type>(|v| v as i32);
+                    write_primitive(typed, array.values(), levels)?
+                }
                 _ => {
                     let array = arrow_cast::cast(column, &ArrowDataType::Int32)?;
                     let array = array
@@ -435,6 +443,12 @@ fn write_leaf(
                     let array: &[i64] = data.buffers()[0].typed_data();
                     write_primitive(typed, &array[offset..offset + data.len()], levels)?
                 }
+                ArrowDataType::Decimal128(_, _) => {
+                    // Decimals that fit in 64 bits are written as int64 values
+                    let array = as_primitive_array::<Decimal128Type>(column)
+                        .unary::<_, types::Int64Type>(|v| v as i64);
+                    write_primitive(typed, array.values(), levels)?
+                }
                 _ => {
                     let array = arrow_cast::cast(column, &ArrowDataType::Int64)?;
                     let array = array
@@ -840,23 +854,32 @@ mod tests {
         roundtrip(batch, Some(SMALL_SIZE / 2));
     }
 
-    #[test]
-    fn arrow_writer_decimal() {
-        let decimal_field = Field::new("a", DataType::Decimal128(5, 2), false);
+    fn get_decimal_batch(precision: u8, scale: i8) -> RecordBatch {
+        let decimal_field =
+            Field::new("a", DataType::Decimal128(precision, scale), false);
         let schema = Schema::new(vec![decimal_field]);
 
         let decimal_values = vec![10_000, 50_000, 0, -100]
             .into_iter()
             .map(Some)
             .collect::<Decimal128Array>()
-            .with_precision_and_scale(5, 2)
+            .with_precision_and_scale(precision, scale)
             .unwrap();
 
-        let batch =
-            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(decimal_values)])
-                .unwrap();
+        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(decimal_values)]).unwrap()
+    }
 
-        roundtrip(batch, Some(SMALL_SIZE / 2));
+    #[test]
+    fn arrow_writer_decimal() {
+        // int32 to store the decimal value
+        let batch_int32_decimal = get_decimal_batch(5, 2);
+        roundtrip(batch_int32_decimal, Some(SMALL_SIZE / 2));
+        // int64 to store the decimal value
+        let batch_int64_decimal = get_decimal_batch(12, 2);
+        roundtrip(batch_int64_decimal, Some(SMALL_SIZE / 2));
+        // fixed_length_byte_array to store the decimal value
+        let batch_fixed_len_byte_array_decimal = get_decimal_batch(30, 2);
+        roundtrip(batch_fixed_len_byte_array_decimal, Some(SMALL_SIZE / 2));
     }
 
     #[test]
diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs
index 120612822..f03a6c695 100644
--- a/parquet/src/arrow/schema/mod.rs
+++ b/parquet/src/arrow/schema/mod.rs
@@ -399,21 +399,32 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
                 .with_length(*length)
                 .build()
         }
-        DataType::Decimal128(precision, scale)
-        | DataType::Decimal256(precision, scale) => {
+        DataType::Decimal128(precision, scale) => {
             // Decimal precision determines the Parquet physical type to use.
-            // TODO(ARROW-12018): Enable the below after ARROW-10818 Decimal support
-            //
-            // let (physical_type, length) = if *precision > 1 && *precision <= 9 {
-            //     (PhysicalType::INT32, -1)
-            // } else if *precision <= 18 {
-            //     (PhysicalType::INT64, -1)
-            // } else {
-            //     (
-            //         PhysicalType::FIXED_LEN_BYTE_ARRAY,
-            //         decimal_length_from_precision(*precision) as i32,
-            //     )
-            // };
+            // This follows the Parquet format specification: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#decimal
+            let (physical_type, length) = if *precision > 1 && *precision <= 9 {
+                (PhysicalType::INT32, -1)
+            } else if *precision <= 18 {
+                (PhysicalType::INT64, -1)
+            } else {
+                (
+                    PhysicalType::FIXED_LEN_BYTE_ARRAY,
+                    decimal_length_from_precision(*precision) as i32,
+                )
+            };
+            Type::primitive_type_builder(name, physical_type)
+                .with_repetition(repetition)
+                .with_length(length)
+                .with_logical_type(Some(LogicalType::Decimal {
+                    scale: *scale as i32,
+                    precision: *precision as i32,
+                }))
+                .with_precision(*precision as i32)
+                .with_scale(*scale as i32)
+                .build()
+        }
+        DataType::Decimal256(precision, scale) => {
+            // Decimal256 is stored as a fixed-length byte array
             Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
                 .with_repetition(repetition)
                 .with_length(decimal_length_from_precision(*precision) as i32)
@@ -627,7 +638,7 @@ mod tests {
             ProjectionMask::all(),
             None,
         )
-        .unwrap();
+            .unwrap();
         assert_eq!(&arrow_fields, converted_arrow_schema.fields());
     }
 
@@ -1257,6 +1268,9 @@ mod tests {
             REPEATED INT32   int_list;
             REPEATED BINARY  byte_list;
             REPEATED BINARY  string_list (UTF8);
+            REQUIRED INT32 decimal_int32 (DECIMAL(8,2));
+            REQUIRED INT64 decimal_int64 (DECIMAL(16,2));
+            REQUIRED FIXED_LEN_BYTE_ARRAY (13) decimal_fix_length (DECIMAL(30,2));
         }
         ";
         let parquet_group_type = parse_message_type(message_type).unwrap();
@@ -1326,6 +1340,20 @@ mod tests {
                 ))),
                 false,
             ),
+            Field::new(
+                "decimal_int32",
+                DataType::Decimal128(8, 2),
+                false,
+            ),
+            Field::new(
+                "decimal_int64",
+                DataType::Decimal128(16, 2),
+                false,
+            ),
+            Field::new(
+                "decimal_fix_length",
+                DataType::Decimal128(30, 2),
+                false, ),
         ];
 
         assert_eq!(arrow_fields, converted_arrow_fields);
@@ -1373,6 +1401,9 @@ mod tests {
                 }
             }
             REQUIRED BINARY  dictionary_strings (STRING);
+            REQUIRED INT32 decimal_int32 (DECIMAL(8,2));
+            REQUIRED INT64 decimal_int64 (DECIMAL(16,2));
+            REQUIRED FIXED_LEN_BYTE_ARRAY (13) decimal_fix_length (DECIMAL(30,2));
         }
         ";
         let parquet_group_type = parse_message_type(message_type).unwrap();
@@ -1458,6 +1489,18 @@ mod tests {
                 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                 false,
             ),
+            Field::new(
+                "decimal_int32",
+                DataType::Decimal128(8, 2),
+                false),
+            Field::new("decimal_int64",
+                       DataType::Decimal128(16, 2),
+                       false),
+            Field::new(
+                "decimal_fix_length",
+                DataType::Decimal128(30, 2),
+                false,
+            ),
         ];
         let arrow_schema = Schema::new(arrow_fields);
         let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema).unwrap();