You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/01/11 02:19:21 UTC
[arrow-rs] branch master updated: Support decimal int32/64 for writer (#3431)
This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new ccb80e82b Support decimal int32/64 for writer (#3431)
ccb80e82b is described below
commit ccb80e82bbfd19e8b353e107ba41cfe0cbaa029a
Author: Kun Liu <li...@apache.org>
AuthorDate: Wed Jan 11 10:19:16 2023 +0800
Support decimal int32/64 for writer (#3431)
---
parquet/src/arrow/arrow_writer/mod.rs | 41 +++++++++++++++-----
parquet/src/arrow/schema/mod.rs | 73 ++++++++++++++++++++++++++++-------
2 files changed, 90 insertions(+), 24 deletions(-)
diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 340ab246a..311981593 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -21,7 +21,9 @@ use std::collections::VecDeque;
use std::io::Write;
use std::sync::Arc;
-use arrow_array::{Array, ArrayRef, RecordBatch};
+use arrow_array::cast::as_primitive_array;
+use arrow_array::types::Decimal128Type;
+use arrow_array::{types, Array, ArrayRef, RecordBatch};
use arrow_schema::{DataType as ArrowDataType, IntervalUnit, SchemaRef};
use super::schema::{
@@ -397,6 +399,12 @@ fn write_leaf(
let array: &[i32] = data.buffers()[0].typed_data();
write_primitive(typed, &array[offset..offset + data.len()], levels)?
}
+ ArrowDataType::Decimal128(_, _) => {
+ // A low-precision decimal is represented as an int32 value
+ let array = as_primitive_array::<Decimal128Type>(column)
+ .unary::<_, types::Int32Type>(|v| v as i32);
+ write_primitive(typed, array.values(), levels)?
+ }
_ => {
let array = arrow_cast::cast(column, &ArrowDataType::Int32)?;
let array = array
@@ -435,6 +443,12 @@ fn write_leaf(
let array: &[i64] = data.buffers()[0].typed_data();
write_primitive(typed, &array[offset..offset + data.len()], levels)?
}
+ ArrowDataType::Decimal128(_, _) => {
+ // A low-precision decimal is represented as an int64 value
+ let array = as_primitive_array::<Decimal128Type>(column)
+ .unary::<_, types::Int64Type>(|v| v as i64);
+ write_primitive(typed, array.values(), levels)?
+ }
_ => {
let array = arrow_cast::cast(column, &ArrowDataType::Int64)?;
let array = array
@@ -840,23 +854,32 @@ mod tests {
roundtrip(batch, Some(SMALL_SIZE / 2));
}
- #[test]
- fn arrow_writer_decimal() {
- let decimal_field = Field::new("a", DataType::Decimal128(5, 2), false);
+ fn get_decimal_batch(precision: u8, scale: i8) -> RecordBatch {
+ let decimal_field =
+ Field::new("a", DataType::Decimal128(precision, scale), false);
let schema = Schema::new(vec![decimal_field]);
let decimal_values = vec![10_000, 50_000, 0, -100]
.into_iter()
.map(Some)
.collect::<Decimal128Array>()
- .with_precision_and_scale(5, 2)
+ .with_precision_and_scale(precision, scale)
.unwrap();
- let batch =
- RecordBatch::try_new(Arc::new(schema), vec![Arc::new(decimal_values)])
- .unwrap();
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(decimal_values)]).unwrap()
+ }
- roundtrip(batch, Some(SMALL_SIZE / 2));
+ #[test]
+ fn arrow_writer_decimal() {
+ // int32 to store the decimal value
+ let batch_int32_decimal = get_decimal_batch(5, 2);
+ roundtrip(batch_int32_decimal, Some(SMALL_SIZE / 2));
+ // int64 to store the decimal value
+ let batch_int64_decimal = get_decimal_batch(12, 2);
+ roundtrip(batch_int64_decimal, Some(SMALL_SIZE / 2));
+ // fixed_length_byte_array to store the decimal value
+ let batch_fixed_len_byte_array_decimal = get_decimal_batch(30, 2);
+ roundtrip(batch_fixed_len_byte_array_decimal, Some(SMALL_SIZE / 2));
}
#[test]
diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs
index 120612822..f03a6c695 100644
--- a/parquet/src/arrow/schema/mod.rs
+++ b/parquet/src/arrow/schema/mod.rs
@@ -399,21 +399,32 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
.with_length(*length)
.build()
}
- DataType::Decimal128(precision, scale)
- | DataType::Decimal256(precision, scale) => {
+ DataType::Decimal128(precision, scale) => {
// Decimal precision determines the Parquet physical type to use.
- // TODO(ARROW-12018): Enable the below after ARROW-10818 Decimal support
- //
- // let (physical_type, length) = if *precision > 1 && *precision <= 9 {
- // (PhysicalType::INT32, -1)
- // } else if *precision <= 18 {
- // (PhysicalType::INT64, -1)
- // } else {
- // (
- // PhysicalType::FIXED_LEN_BYTE_ARRAY,
- // decimal_length_from_precision(*precision) as i32,
- // )
- // };
+ // This follows the Parquet format specification: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#decimal
+ let (physical_type, length) = if *precision > 1 && *precision <= 9 {
+ (PhysicalType::INT32, -1)
+ } else if *precision <= 18 {
+ (PhysicalType::INT64, -1)
+ } else {
+ (
+ PhysicalType::FIXED_LEN_BYTE_ARRAY,
+ decimal_length_from_precision(*precision) as i32,
+ )
+ };
+ Type::primitive_type_builder(name, physical_type)
+ .with_repetition(repetition)
+ .with_length(length)
+ .with_logical_type(Some(LogicalType::Decimal {
+ scale: *scale as i32,
+ precision: *precision as i32,
+ }))
+ .with_precision(*precision as i32)
+ .with_scale(*scale as i32)
+ .build()
+ }
+ DataType::Decimal256(precision, scale) => {
+ // Decimal256 values are stored as a fixed-length byte array
Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
.with_repetition(repetition)
.with_length(decimal_length_from_precision(*precision) as i32)
@@ -627,7 +638,7 @@ mod tests {
ProjectionMask::all(),
None,
)
- .unwrap();
+ .unwrap();
assert_eq!(&arrow_fields, converted_arrow_schema.fields());
}
@@ -1257,6 +1268,9 @@ mod tests {
REPEATED INT32 int_list;
REPEATED BINARY byte_list;
REPEATED BINARY string_list (UTF8);
+ REQUIRED INT32 decimal_int32 (DECIMAL(8,2));
+ REQUIRED INT64 decimal_int64 (DECIMAL(16,2));
+ REQUIRED FIXED_LEN_BYTE_ARRAY (13) decimal_fix_length (DECIMAL(30,2));
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
@@ -1326,6 +1340,20 @@ mod tests {
))),
false,
),
+ Field::new(
+ "decimal_int32",
+ DataType::Decimal128(8, 2),
+ false,
+ ),
+ Field::new(
+ "decimal_int64",
+ DataType::Decimal128(16, 2),
+ false,
+ ),
+ Field::new(
+ "decimal_fix_length",
+ DataType::Decimal128(30, 2),
+ false, ),
];
assert_eq!(arrow_fields, converted_arrow_fields);
@@ -1373,6 +1401,9 @@ mod tests {
}
}
REQUIRED BINARY dictionary_strings (STRING);
+ REQUIRED INT32 decimal_int32 (DECIMAL(8,2));
+ REQUIRED INT64 decimal_int64 (DECIMAL(16,2));
+ REQUIRED FIXED_LEN_BYTE_ARRAY (13) decimal_fix_length (DECIMAL(30,2));
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
@@ -1458,6 +1489,18 @@ mod tests {
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
false,
),
+ Field::new(
+ "decimal_int32",
+ DataType::Decimal128(8, 2),
+ false),
+ Field::new("decimal_int64",
+ DataType::Decimal128(16, 2),
+ false),
+ Field::new(
+ "decimal_fix_length",
+ DataType::Decimal128(30, 2),
+ false,
+ ),
];
let arrow_schema = Schema::new(arrow_fields);
let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema).unwrap();