You are viewing a plain-text version of this content; the hyperlink to its canonical location was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/07/14 11:51:21 UTC
[arrow-rs] branch master updated: generate parquet schema from rust
struct (#539)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 55a5863 generate parquet schema from rust struct (#539)
55a5863 is described below
commit 55a58634ff9f1090cc9c2db770f84e3502c38b34
Author: Wakahisa <ne...@gmail.com>
AuthorDate: Wed Jul 14 13:51:17 2021 +0200
generate parquet schema from rust struct (#539)
* generate parquet schema from rust struct
* support all primitive types through logical types
---
parquet/src/record/record_writer.rs | 5 ++
parquet_derive/src/lib.rs | 31 +++++--
parquet_derive/src/parquet_field.rs | 157 ++++++++++++++++++++++++++++++++++--
parquet_derive_test/src/lib.rs | 84 ++++++++++++-------
4 files changed, 236 insertions(+), 41 deletions(-)
diff --git a/parquet/src/record/record_writer.rs b/parquet/src/record/record_writer.rs
index 56817eb..6668eec 100644
--- a/parquet/src/record/record_writer.rs
+++ b/parquet/src/record/record_writer.rs
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+use crate::schema::types::TypePtr;
+
use super::super::errors::ParquetError;
use super::super::file::writer::RowGroupWriter;
@@ -23,4 +25,7 @@ pub trait RecordWriter<T> {
&self,
row_group_writer: &mut Box<dyn RowGroupWriter>,
) -> Result<(), ParquetError>;
+
+ /// Returns the Parquet schema describing the records this writer
+ /// produces, suitable for passing to a file writer.
+ ///
+ /// # Errors
+ ///
+ /// Returns a [`ParquetError`] if the schema cannot be constructed.
+ fn schema(&self) -> Result<TypePtr, ParquetError>;
}
diff --git a/parquet_derive/src/lib.rs b/parquet_derive/src/lib.rs
index 279d0f7..1c53227 100644
--- a/parquet_derive/src/lib.rs
+++ b/parquet_derive/src/lib.rs
@@ -52,11 +52,6 @@ mod parquet_field;
/// pub a_str: &'a str,
/// }
///
-/// let schema_str = "message schema {
-/// REQUIRED boolean a_bool;
-/// REQUIRED BINARY a_str (UTF8);
-/// }";
-///
/// pub fn write_some_records() {
/// let samples = vec![
/// ACompleteRecord {
@@ -69,7 +64,7 @@ mod parquet_field;
/// }
/// ];
///
-/// let schema = Arc::new(parse_message_type(schema_str).unwrap());
+/// let schema = samples.as_slice().schema().unwrap();
///
/// let props = Arc::new(WriterProperties::builder().build());
/// let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
@@ -101,9 +96,15 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke
let derived_for = input.ident;
let generics = input.generics;
+ let field_types: Vec<proc_macro2::TokenStream> =
+ field_infos.iter().map(|x| x.parquet_type()).collect();
+
(quote! {
impl#generics RecordWriter<#derived_for#generics> for &[#derived_for#generics] {
- fn write_to_row_group(&self, row_group_writer: &mut Box<parquet::file::writer::RowGroupWriter>) -> Result<(), parquet::errors::ParquetError> {
+ fn write_to_row_group(
+ &self,
+ row_group_writer: &mut Box<parquet::file::writer::RowGroupWriter>
+ ) -> Result<(), parquet::errors::ParquetError> {
let mut row_group_writer = row_group_writer;
let records = &self; // Used by all the writer snippets to be more clear
@@ -121,6 +122,22 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke
Ok(())
}
+
+        // Builds the Parquet schema for this record type: a root group named
+        // `rust_schema` holding one primitive column per struct field, in
+        // declaration order.
+        fn schema(&self) -> Result<parquet::schema::types::TypePtr, parquet::errors::ParquetError> {
+            // The per-field snippets refer to `ParquetType`, `LogicalType`,
+            // `IntType` and `Repetition` by short name, so bring them into scope.
+            use parquet::schema::types::Type as ParquetType;
+            use parquet::schema::types::TypePtr;
+            use parquet::basic::*;
+
+            let mut fields: Vec<TypePtr> = Vec::new();
+            // Each snippet is already a complete `fields.push(...);` statement,
+            // so splice them back to back with no extra separator.
+            #(#field_types)*
+            let group = ParquetType::group_type_builder("rust_schema")
+                .with_fields(&mut fields)
+                .build()?;
+            Ok(group.into())
+        }
}
}).into()
}
diff --git a/parquet_derive/src/parquet_field.rs b/parquet_derive/src/parquet_field.rs
index 328f4a6..6f2fa0c 100644
--- a/parquet_derive/src/parquet_field.rs
+++ b/parquet_derive/src/parquet_field.rs
@@ -174,6 +174,50 @@ impl Field {
}
}
+    /// Returns a token stream that, when spliced into the generated
+    /// `RecordWriter::schema` method, builds this field's primitive Parquet
+    /// column type and appends it to that method's local `fields` vector.
+    ///
+    /// The emitted code propagates builder failures with `?`, so a column the
+    /// schema builder rejects surfaces as an `Err` from `schema()` rather than
+    /// a panic in the user's program.
+    pub fn parquet_type(&self) -> proc_macro2::TokenStream {
+        // TODO: Support group types
+        // TODO: Add length if dealing with fixedlenbinary
+        use parquet::basic::Type as PhysicalType;
+
+        let field_name = self.ident.to_string();
+        // Re-spell the physical type as a fully qualified path so the
+        // generated code does not depend on what the caller has imported.
+        let physical_type = match self.ty.physical_type() {
+            PhysicalType::BOOLEAN => quote! { parquet::basic::Type::BOOLEAN },
+            PhysicalType::INT32 => quote! { parquet::basic::Type::INT32 },
+            PhysicalType::INT64 => quote! { parquet::basic::Type::INT64 },
+            PhysicalType::INT96 => quote! { parquet::basic::Type::INT96 },
+            PhysicalType::FLOAT => quote! { parquet::basic::Type::FLOAT },
+            PhysicalType::DOUBLE => quote! { parquet::basic::Type::DOUBLE },
+            PhysicalType::BYTE_ARRAY => quote! { parquet::basic::Type::BYTE_ARRAY },
+            PhysicalType::FIXED_LEN_BYTE_ARRAY => {
+                quote! { parquet::basic::Type::FIXED_LEN_BYTE_ARRAY }
+            }
+        };
+        let logical_type = self.ty.logical_type();
+        let repetition = self.ty.repetition();
+        quote! {
+            fields.push(
+                ParquetType::primitive_type_builder(#field_name, #physical_type)
+                    .with_logical_type(#logical_type)
+                    .with_repetition(#repetition)
+                    .build()?
+                    .into()
+            );
+        }
+    }
+
fn option_into_vals(&self) -> proc_macro2::TokenStream {
let field_name = &self.ident;
let is_a_byte_buf = self.is_a_byte_buf;
@@ -201,7 +245,12 @@ impl Field {
} else if is_a_byte_buf {
quote! { Some((&inner[..]).into())}
} else {
- quote! { Some(inner) }
+ // Type might need converting to a physical type
+ match self.ty.physical_type() {
+ parquet::basic::Type::INT32 => quote! { Some(inner as i32) },
+ parquet::basic::Type::INT64 => quote! { Some(inner as i64) },
+ _ => quote! { Some(inner) },
+ }
};
quote! {
@@ -232,7 +281,12 @@ impl Field {
} else if is_a_byte_buf {
quote! { (&rec.#field_name[..]).into() }
} else {
- quote! { rec.#field_name }
+ // Type might need converting to a physical type
+ match self.ty.physical_type() {
+ parquet::basic::Type::INT32 => quote! { rec.#field_name as i32 },
+ parquet::basic::Type::INT64 => quote! { rec.#field_name as i64 },
+ _ => quote! { rec.#field_name },
+ }
};
quote! {
@@ -403,7 +457,14 @@ impl Type {
"bool" => BasicType::BOOLEAN,
"u8" | "u16" | "u32" => BasicType::INT32,
"i8" | "i16" | "i32" | "NaiveDate" => BasicType::INT32,
- "u64" | "i64" | "usize" | "NaiveDateTime" => BasicType::INT64,
+ "u64" | "i64" | "NaiveDateTime" => BasicType::INT64,
+ "usize" | "isize" => {
+ if usize::BITS == 64 {
+ BasicType::INT64
+ } else {
+ BasicType::INT32
+ }
+ }
"f32" => BasicType::FLOAT,
"f64" => BasicType::DOUBLE,
"String" | "str" | "Uuid" => BasicType::BYTE_ARRAY,
@@ -411,6 +472,83 @@ impl Type {
}
}
+    /// Returns a token stream for the `Option<LogicalType>` annotation of this
+    /// field's Parquet column, passed to `with_logical_type` by the generated
+    /// schema code.
+    ///
+    /// # Panics
+    ///
+    /// Panics at macro-expansion time for leaf types not matched below.
+    fn logical_type(&self) -> proc_macro2::TokenStream {
+        let last_part = self.last_part();
+        let leaf_type = self.leaf_type_recursive();
+
+        // A `u8` leaf inside a `Vec` or `Array` is a byte buffer: written as
+        // raw binary with no logical annotation.
+        match leaf_type {
+            Type::Array(ref first_type) | Type::Vec(ref first_type) => {
+                if let Type::TypePath(_) = **first_type {
+                    if last_part == "u8" {
+                        return quote! { None };
+                    }
+                }
+            }
+            _ => (),
+        }
+
+        match last_part.trim() {
+            "bool" => quote! { None },
+            "u8" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 8,
+                is_signed: false,
+            })) },
+            "u16" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 16,
+                is_signed: false,
+            })) },
+            "u32" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 32,
+                is_signed: false,
+            })) },
+            "u64" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 64,
+                is_signed: false,
+            })) },
+            "i8" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 8,
+                is_signed: true,
+            })) },
+            "i16" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 16,
+                is_signed: true,
+            })) },
+            "i32" | "i64" => quote! { None },
+            // `physical_type` chooses INT32/INT64 for `usize`/`isize` from the
+            // pointer width of the machine running this macro, so the declared
+            // width is fixed here as well. Emitting `usize::BITS` into the
+            // generated code would instead use the target's width, which can
+            // disagree with the physical type when cross-compiling.
+            "usize" => {
+                let bit_width = usize::BITS as i8;
+                quote! { Some(LogicalType::INTEGER(IntType {
+                    bit_width: #bit_width,
+                    is_signed: false,
+                })) }
+            }
+            "isize" => {
+                let bit_width = usize::BITS as i8;
+                quote! { Some(LogicalType::INTEGER(IntType {
+                    bit_width: #bit_width,
+                    is_signed: true,
+                })) }
+            }
+            "NaiveDate" => quote! { Some(LogicalType::DATE(Default::default())) },
+            // Mapped to INT64 by `physical_type`; without this arm any struct
+            // containing a `NaiveDateTime` hits `unimplemented!` below. Left
+            // unannotated rather than asserting a timestamp unit.
+            "NaiveDateTime" => quote! { None },
+            "f32" | "f64" => quote! { None },
+            "String" | "str" => quote! { Some(LogicalType::STRING(Default::default())) },
+            // NOTE(review): the Parquet spec defines UUID over
+            // FIXED_LEN_BYTE_ARRAY(16), while `physical_type` maps `Uuid` to
+            // BYTE_ARRAY — confirm the schema builder accepts this pairing.
+            "Uuid" => quote! { Some(LogicalType::UUID(Default::default())) },
+            f => unimplemented!("{} currently is not supported", f),
+        }
+    }
+
+    /// Returns the `Repetition` tokens for this field's column: references
+    /// are looked through, an `Option<_>` is `OPTIONAL`, and everything else
+    /// is `REQUIRED`.
+    fn repetition(&self) -> proc_macro2::TokenStream {
+        if let Type::Reference(_, inner) = self {
+            return inner.repetition();
+        }
+        if matches!(self, Type::Option(_)) {
+            quote! { Repetition::OPTIONAL }
+        } else {
+            quote! { Repetition::REQUIRED }
+        }
+    }
+
/// Convert a parsed rust field AST in to a more easy to manipulate
/// parquet_derive::Field
fn from(f: &syn::Field) -> Self {
@@ -505,7 +643,7 @@ mod test {
assert_eq!(snippet,
(quote!{
{
- let vals : Vec < _ > = records . iter ( ) . map ( | rec | rec . counter ) . collect ( );
+ let vals : Vec < _ > = records . iter ( ) . map ( | rec | rec . counter as i64 ) . collect ( );
if let parquet::column::writer::ColumnWriter::Int64ColumnWriter ( ref mut typed ) = column_writer {
typed . write_batch ( & vals [ .. ] , None , None ) ?;
@@ -585,7 +723,7 @@ mod test {
let vals: Vec <_> = records.iter().filter_map( |rec| {
if let Some ( inner ) = rec . optional_dumb_int {
- Some ( inner )
+ Some ( inner as i32 )
} else {
None
}
@@ -636,12 +774,13 @@ mod test {
struct ABasicStruct {
yes_no: bool,
name: String,
+ length: usize
}
};
let fields = extract_fields(snippet);
let processed: Vec<_> = fields.iter().map(|field| Field::from(field)).collect();
- assert_eq!(processed.len(), 2);
+ assert_eq!(processed.len(), 3);
assert_eq!(
processed,
@@ -657,6 +796,12 @@ mod test {
ty: Type::TypePath(syn::parse_quote!(String)),
is_a_byte_buf: true,
third_party_type: None,
+ },
+ Field {
+ ident: syn::Ident::new("length", proc_macro2::Span::call_site()),
+ ty: Type::TypePath(syn::parse_quote!(usize)),
+ is_a_byte_buf: false,
+ third_party_type: None,
}
]
)
diff --git a/parquet_derive_test/src/lib.rs b/parquet_derive_test/src/lib.rs
index b4bfc42..bc8e914 100644
--- a/parquet_derive_test/src/lib.rs
+++ b/parquet_derive_test/src/lib.rs
@@ -32,11 +32,18 @@ struct ACompleteRecord<'a> {
pub a_borrowed_string: &'a String,
pub maybe_a_str: Option<&'a str>,
pub maybe_a_string: Option<String>,
- pub magic_number: i32,
- pub low_quality_pi: f32,
- pub high_quality_pi: f64,
- pub maybe_pi: Option<f32>,
- pub maybe_best_pi: Option<f64>,
+ pub i16: i16,
+ pub i32: i32,
+ pub u64: u64,
+ pub maybe_u8: Option<u8>,
+ pub maybe_i16: Option<i16>,
+ pub maybe_u32: Option<u32>,
+ pub maybe_usize: Option<usize>,
+ pub isize: isize,
+ pub float: f32,
+ pub double: f64,
+ pub maybe_float: Option<f32>,
+ pub maybe_double: Option<f64>,
pub borrowed_maybe_a_string: &'a Option<String>,
pub borrowed_maybe_a_str: &'a Option<&'a str>,
}
@@ -57,27 +64,32 @@ mod tests {
#[test]
fn test_parquet_derive_hello() {
let file = get_temp_file("test_parquet_derive_hello", &[]);
- let schema_str = "message schema {
+
+ // The schema is not required, but this tests that the generated
+ // schema agrees with what one would write by hand.
+ let schema_str = "message rust_schema {
REQUIRED boolean a_bool;
- REQUIRED BINARY a_str (UTF8);
- REQUIRED BINARY a_string (UTF8);
- REQUIRED BINARY a_borrowed_string (UTF8);
- OPTIONAL BINARY a_maybe_str (UTF8);
- OPTIONAL BINARY a_maybe_string (UTF8);
- REQUIRED INT32 magic_number;
- REQUIRED FLOAT low_quality_pi;
- REQUIRED DOUBLE high_quality_pi;
- OPTIONAL FLOAT maybe_pi;
- OPTIONAL DOUBLE maybe_best_pi;
- OPTIONAL BINARY borrowed_maybe_a_string (UTF8);
- OPTIONAL BINARY borrowed_maybe_a_str (UTF8);
+ REQUIRED BINARY a_str (STRING);
+ REQUIRED BINARY a_string (STRING);
+ REQUIRED BINARY a_borrowed_string (STRING);
+ OPTIONAL BINARY maybe_a_str (STRING);
+ OPTIONAL BINARY maybe_a_string (STRING);
+ REQUIRED INT32 i16 (INTEGER(16,true));
+ REQUIRED INT32 i32;
+ REQUIRED INT64 u64 (INTEGER(64,false));
+ OPTIONAL INT32 maybe_u8 (INTEGER(8,false));
+ OPTIONAL INT32 maybe_i16 (INTEGER(16,true));
+ OPTIONAL INT32 maybe_u32 (INTEGER(32,false));
+ OPTIONAL INT64 maybe_usize (INTEGER(64,false));
+ REQUIRED INT64 isize (INTEGER(64,true));
+ REQUIRED FLOAT float;
+ REQUIRED DOUBLE double;
+ OPTIONAL FLOAT maybe_float;
+ OPTIONAL DOUBLE maybe_double;
+ OPTIONAL BINARY borrowed_maybe_a_string (STRING);
+ OPTIONAL BINARY borrowed_maybe_a_str (STRING);
}";
- let schema = Arc::new(parse_message_type(schema_str).unwrap());
-
- let props = Arc::new(WriterProperties::builder().build());
- let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
-
let a_str = "hello mother".to_owned();
let a_borrowed_string = "cool news".to_owned();
let maybe_a_string = Some("it's true, I'm a string".to_owned());
@@ -90,15 +102,31 @@ mod tests {
a_borrowed_string: &a_borrowed_string,
maybe_a_str: Some(&a_str[..]),
maybe_a_string: Some(a_str.clone()),
- magic_number: 100,
- low_quality_pi: 3.14,
- high_quality_pi: 3.1415,
- maybe_pi: Some(3.14),
- maybe_best_pi: Some(3.1415),
+ i16: -45,
+ i32: 456,
+ u64: 4563424,
+ maybe_u8: None,
+ maybe_i16: Some(3),
+ maybe_u32: None,
+ maybe_usize: Some(4456),
+ isize: -365,
+ float: 3.5,
+ double: std::f64::NAN,
+ maybe_float: None,
+ maybe_double: Some(std::f64::MAX),
borrowed_maybe_a_string: &maybe_a_string,
borrowed_maybe_a_str: &maybe_a_str,
}];
+ let schema = Arc::new(parse_message_type(schema_str).unwrap());
+ let generated_schema = drs.as_slice().schema().unwrap();
+
+ assert_eq!(&schema, &generated_schema);
+
+ let props = Arc::new(WriterProperties::builder().build());
+ let mut writer =
+ SerializedFileWriter::new(file, generated_schema, props).unwrap();
+
let mut row_group = writer.next_row_group().unwrap();
drs.as_slice().write_to_row_group(&mut row_group).unwrap();
writer.close_row_group(row_group).unwrap();