You are viewing a plain-text version of this content. The canonical link to the original message was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/07/14 11:51:21 UTC

[arrow-rs] branch master updated: generate parquet schema from rust struct (#539)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 55a5863  generate parquet schema from rust struct (#539)
55a5863 is described below

commit 55a58634ff9f1090cc9c2db770f84e3502c38b34
Author: Wakahisa <ne...@gmail.com>
AuthorDate: Wed Jul 14 13:51:17 2021 +0200

    generate parquet schema from rust struct (#539)
    
    * generate parquet schema from rust struct
    
    * support all primitive types through logical types
---
 parquet/src/record/record_writer.rs |   5 ++
 parquet_derive/src/lib.rs           |  31 +++++--
 parquet_derive/src/parquet_field.rs | 157 ++++++++++++++++++++++++++++++++++--
 parquet_derive_test/src/lib.rs      |  84 ++++++++++++-------
 4 files changed, 236 insertions(+), 41 deletions(-)

diff --git a/parquet/src/record/record_writer.rs b/parquet/src/record/record_writer.rs
index 56817eb..6668eec 100644
--- a/parquet/src/record/record_writer.rs
+++ b/parquet/src/record/record_writer.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::schema::types::TypePtr;
+
 use super::super::errors::ParquetError;
 use super::super::file::writer::RowGroupWriter;
 
@@ -23,4 +25,7 @@ pub trait RecordWriter<T> {
         &self,
         row_group_writer: &mut Box<dyn RowGroupWriter>,
     ) -> Result<(), ParquetError>;
+
+    /// Generated schema
+    fn schema(&self) -> Result<TypePtr, ParquetError>;
 }
diff --git a/parquet_derive/src/lib.rs b/parquet_derive/src/lib.rs
index 279d0f7..1c53227 100644
--- a/parquet_derive/src/lib.rs
+++ b/parquet_derive/src/lib.rs
@@ -52,11 +52,6 @@ mod parquet_field;
 ///   pub a_str: &'a str,
 /// }
 ///
-/// let schema_str = "message schema {
-///   REQUIRED boolean         a_bool;
-///   REQUIRED BINARY          a_str (UTF8);
-/// }";
-///
 /// pub fn write_some_records() {
 ///   let samples = vec![
 ///     ACompleteRecord {
@@ -69,7 +64,7 @@ mod parquet_field;
 ///     }
 ///   ];
 ///
-///  let schema = Arc::new(parse_message_type(schema_str).unwrap());
+///  let schema = samples.as_slice().schema();
 ///
 ///  let props = Arc::new(WriterProperties::builder().build());
 ///  let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
@@ -101,9 +96,15 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke
     let derived_for = input.ident;
     let generics = input.generics;
 
+    let field_types: Vec<proc_macro2::TokenStream> =
+        field_infos.iter().map(|x| x.parquet_type()).collect();
+
     (quote! {
     impl#generics RecordWriter<#derived_for#generics> for &[#derived_for#generics] {
-      fn write_to_row_group(&self, row_group_writer: &mut Box<parquet::file::writer::RowGroupWriter>) -> Result<(), parquet::errors::ParquetError> {
+      fn write_to_row_group(
+        &self,
+        row_group_writer: &mut Box<parquet::file::writer::RowGroupWriter>
+      ) -> Result<(), parquet::errors::ParquetError> {
         let mut row_group_writer = row_group_writer;
         let records = &self; // Used by all the writer snippets to be more clear
 
@@ -121,6 +122,22 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke
 
         Ok(())
       }
+
+      fn schema(&self) -> Result<parquet::schema::types::TypePtr, parquet::errors::ParquetError> {
+        use parquet::schema::types::Type as ParquetType;
+        use parquet::schema::types::TypePtr;
+        use parquet::basic::LogicalType;
+        use parquet::basic::*;
+
+        let mut fields: Vec<TypePtr> = Vec::new();
+        #(
+          #field_types
+        );*;
+        let group = parquet::schema::types::Type::group_type_builder("rust_schema")
+          .with_fields(&mut fields)
+          .build()?;
+        Ok(group.into())
+      }
     }
   }).into()
 }
diff --git a/parquet_derive/src/parquet_field.rs b/parquet_derive/src/parquet_field.rs
index 328f4a6..6f2fa0c 100644
--- a/parquet_derive/src/parquet_field.rs
+++ b/parquet_derive/src/parquet_field.rs
@@ -174,6 +174,50 @@ impl Field {
         }
     }
 
+    pub fn parquet_type(&self) -> proc_macro2::TokenStream {
+        // TODO: Support group types
+        // TODO: Add length if dealing with fixedlenbinary
+
+        let field_name = &self.ident.to_string();
+        let physical_type = match self.ty.physical_type() {
+            parquet::basic::Type::BOOLEAN => quote! {
+                parquet::basic::Type::BOOLEAN
+            },
+            parquet::basic::Type::INT32 => quote! {
+                parquet::basic::Type::INT32
+            },
+            parquet::basic::Type::INT64 => quote! {
+                parquet::basic::Type::INT64
+            },
+            parquet::basic::Type::INT96 => quote! {
+                parquet::basic::Type::INT96
+            },
+            parquet::basic::Type::FLOAT => quote! {
+                parquet::basic::Type::FLOAT
+            },
+            parquet::basic::Type::DOUBLE => quote! {
+                parquet::basic::Type::DOUBLE
+            },
+            parquet::basic::Type::BYTE_ARRAY => quote! {
+                parquet::basic::Type::BYTE_ARRAY
+            },
+            parquet::basic::Type::FIXED_LEN_BYTE_ARRAY => quote! {
+                parquet::basic::Type::FIXED_LEN_BYTE_ARRAY
+            },
+        };
+        let logical_type = self.ty.logical_type();
+        let repetition = self.ty.repetition();
+        quote! {
+            fields.push(ParquetType::primitive_type_builder(#field_name, #physical_type)
+                .with_logical_type(#logical_type)
+                .with_repetition(#repetition)
+                .build()
+                .unwrap()
+                .into()
+            );
+        }
+    }
+
     fn option_into_vals(&self) -> proc_macro2::TokenStream {
         let field_name = &self.ident;
         let is_a_byte_buf = self.is_a_byte_buf;
@@ -201,7 +245,12 @@ impl Field {
         } else if is_a_byte_buf {
             quote! { Some((&inner[..]).into())}
         } else {
-            quote! { Some(inner) }
+            // Type might need converting to a physical type
+            match self.ty.physical_type() {
+                parquet::basic::Type::INT32 => quote! { Some(inner as i32) },
+                parquet::basic::Type::INT64 => quote! { Some(inner as i64) },
+                _ => quote! { Some(inner) },
+            }
         };
 
         quote! {
@@ -232,7 +281,12 @@ impl Field {
         } else if is_a_byte_buf {
             quote! { (&rec.#field_name[..]).into() }
         } else {
-            quote! { rec.#field_name }
+            // Type might need converting to a physical type
+            match self.ty.physical_type() {
+                parquet::basic::Type::INT32 => quote! { rec.#field_name as i32 },
+                parquet::basic::Type::INT64 => quote! { rec.#field_name as i64 },
+                _ => quote! { rec.#field_name },
+            }
         };
 
         quote! {
@@ -403,7 +457,14 @@ impl Type {
             "bool" => BasicType::BOOLEAN,
             "u8" | "u16" | "u32" => BasicType::INT32,
             "i8" | "i16" | "i32" | "NaiveDate" => BasicType::INT32,
-            "u64" | "i64" | "usize" | "NaiveDateTime" => BasicType::INT64,
+            "u64" | "i64" | "NaiveDateTime" => BasicType::INT64,
+            "usize" | "isize" => {
+                if usize::BITS == 64 {
+                    BasicType::INT64
+                } else {
+                    BasicType::INT32
+                }
+            }
             "f32" => BasicType::FLOAT,
             "f64" => BasicType::DOUBLE,
             "String" | "str" | "Uuid" => BasicType::BYTE_ARRAY,
@@ -411,6 +472,83 @@ impl Type {
         }
     }
 
+    fn logical_type(&self) -> proc_macro2::TokenStream {
+        let last_part = self.last_part();
+        let leaf_type = self.leaf_type_recursive();
+
+        match leaf_type {
+            Type::Array(ref first_type) => {
+                if let Type::TypePath(_) = **first_type {
+                    if last_part == "u8" {
+                        return quote! { None };
+                    }
+                }
+            }
+            Type::Vec(ref first_type) => {
+                if let Type::TypePath(_) = **first_type {
+                    if last_part == "u8" {
+                        return quote! { None };
+                    }
+                }
+            }
+            _ => (),
+        }
+
+        match last_part.trim() {
+            "bool" => quote! { None },
+            "u8" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 8,
+                is_signed: false,
+            })) },
+            "u16" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 16,
+                is_signed: false,
+            })) },
+            "u32" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 32,
+                is_signed: false,
+            })) },
+            "u64" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 64,
+                is_signed: false,
+            })) },
+            "i8" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 8,
+                is_signed: true,
+            })) },
+            "i16" => quote! { Some(LogicalType::INTEGER(IntType {
+                bit_width: 16,
+                is_signed: true,
+            })) },
+            "i32" | "i64" => quote! { None },
+            "usize" => {
+                quote! { Some(LogicalType::INTEGER(IntType {
+                    bit_width: usize::BITS as i8,
+                    is_signed: false
+                })) }
+            }
+            "isize" => {
+                quote! { Some(LogicalType::INTEGER(IntType {
+                    bit_width: usize::BITS as i8,
+                    is_signed: true
+                })) }
+            }
+            "NaiveDate" => quote! { Some(LogicalType::DATE(Default::default())) },
+            "f32" | "f64" => quote! { None },
+            "String" | "str" => quote! { Some(LogicalType::STRING(Default::default())) },
+            "Uuid" => quote! { Some(LogicalType::UUID(Default::default())) },
+            f => unimplemented!("{} currently is not supported", f),
+        }
+    }
+
+    fn repetition(&self) -> proc_macro2::TokenStream {
+        match &self {
+            Type::Option(_) => quote! { Repetition::OPTIONAL },
+            Type::Reference(_, ty) => ty.repetition(),
+            _ => quote! { Repetition::REQUIRED },
+        }
+    }
+
     /// Convert a parsed rust field AST in to a more easy to manipulate
     /// parquet_derive::Field
     fn from(f: &syn::Field) -> Self {
@@ -505,7 +643,7 @@ mod test {
         assert_eq!(snippet,
                    (quote!{
                         {
-                            let vals : Vec < _ > = records . iter ( ) . map ( | rec | rec . counter ) . collect ( );
+                            let vals : Vec < _ > = records . iter ( ) . map ( | rec | rec . counter as i64 ) . collect ( );
 
                             if let parquet::column::writer::ColumnWriter::Int64ColumnWriter ( ref mut typed ) = column_writer {
                                 typed . write_batch ( & vals [ .. ] , None , None ) ?;
@@ -585,7 +723,7 @@ mod test {
 
                         let vals: Vec <_> = records.iter().filter_map( |rec| {
                             if let Some ( inner ) = rec . optional_dumb_int {
-                                Some ( inner )
+                                Some ( inner as i32 )
                             } else {
                                 None
                             }
@@ -636,12 +774,13 @@ mod test {
           struct ABasicStruct {
             yes_no: bool,
             name: String,
+            length: usize
           }
         };
 
         let fields = extract_fields(snippet);
         let processed: Vec<_> = fields.iter().map(|field| Field::from(field)).collect();
-        assert_eq!(processed.len(), 2);
+        assert_eq!(processed.len(), 3);
 
         assert_eq!(
             processed,
@@ -657,6 +796,12 @@ mod test {
                     ty: Type::TypePath(syn::parse_quote!(String)),
                     is_a_byte_buf: true,
                     third_party_type: None,
+                },
+                Field {
+                    ident: syn::Ident::new("length", proc_macro2::Span::call_site()),
+                    ty: Type::TypePath(syn::parse_quote!(usize)),
+                    is_a_byte_buf: false,
+                    third_party_type: None,
                 }
             ]
         )
diff --git a/parquet_derive_test/src/lib.rs b/parquet_derive_test/src/lib.rs
index b4bfc42..bc8e914 100644
--- a/parquet_derive_test/src/lib.rs
+++ b/parquet_derive_test/src/lib.rs
@@ -32,11 +32,18 @@ struct ACompleteRecord<'a> {
     pub a_borrowed_string: &'a String,
     pub maybe_a_str: Option<&'a str>,
     pub maybe_a_string: Option<String>,
-    pub magic_number: i32,
-    pub low_quality_pi: f32,
-    pub high_quality_pi: f64,
-    pub maybe_pi: Option<f32>,
-    pub maybe_best_pi: Option<f64>,
+    pub i16: i16,
+    pub i32: i32,
+    pub u64: u64,
+    pub maybe_u8: Option<u8>,
+    pub maybe_i16: Option<i16>,
+    pub maybe_u32: Option<u32>,
+    pub maybe_usize: Option<usize>,
+    pub isize: isize,
+    pub float: f32,
+    pub double: f64,
+    pub maybe_float: Option<f32>,
+    pub maybe_double: Option<f64>,
     pub borrowed_maybe_a_string: &'a Option<String>,
     pub borrowed_maybe_a_str: &'a Option<&'a str>,
 }
@@ -57,27 +64,32 @@ mod tests {
     #[test]
     fn test_parquet_derive_hello() {
         let file = get_temp_file("test_parquet_derive_hello", &[]);
-        let schema_str = "message schema {
+
+        // The schema is not required, but this tests that the generated
+        // schema agrees with what one would write by hand.
+        let schema_str = "message rust_schema {
             REQUIRED boolean         a_bool;
-            REQUIRED BINARY          a_str (UTF8);
-            REQUIRED BINARY          a_string (UTF8);
-            REQUIRED BINARY          a_borrowed_string (UTF8);
-            OPTIONAL BINARY          a_maybe_str (UTF8);
-            OPTIONAL BINARY          a_maybe_string (UTF8);
-            REQUIRED INT32           magic_number;
-            REQUIRED FLOAT           low_quality_pi;
-            REQUIRED DOUBLE          high_quality_pi;
-            OPTIONAL FLOAT           maybe_pi;
-            OPTIONAL DOUBLE          maybe_best_pi;
-            OPTIONAL BINARY          borrowed_maybe_a_string (UTF8);
-            OPTIONAL BINARY          borrowed_maybe_a_str (UTF8);
+            REQUIRED BINARY          a_str (STRING);
+            REQUIRED BINARY          a_string (STRING);
+            REQUIRED BINARY          a_borrowed_string (STRING);
+            OPTIONAL BINARY          maybe_a_str (STRING);
+            OPTIONAL BINARY          maybe_a_string (STRING);
+            REQUIRED INT32           i16 (INTEGER(16,true));
+            REQUIRED INT32           i32;
+            REQUIRED INT64           u64 (INTEGER(64,false));
+            OPTIONAL INT32           maybe_u8 (INTEGER(8,false));
+            OPTIONAL INT32           maybe_i16 (INTEGER(16,true));
+            OPTIONAL INT32           maybe_u32 (INTEGER(32,false));
+            OPTIONAL INT64           maybe_usize (INTEGER(64,false));
+            REQUIRED INT64           isize (INTEGER(64,true));
+            REQUIRED FLOAT           float;
+            REQUIRED DOUBLE          double;
+            OPTIONAL FLOAT           maybe_float;
+            OPTIONAL DOUBLE          maybe_double;
+            OPTIONAL BINARY          borrowed_maybe_a_string (STRING);
+            OPTIONAL BINARY          borrowed_maybe_a_str (STRING);
         }";
 
-        let schema = Arc::new(parse_message_type(schema_str).unwrap());
-
-        let props = Arc::new(WriterProperties::builder().build());
-        let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
-
         let a_str = "hello mother".to_owned();
         let a_borrowed_string = "cool news".to_owned();
         let maybe_a_string = Some("it's true, I'm a string".to_owned());
@@ -90,15 +102,31 @@ mod tests {
             a_borrowed_string: &a_borrowed_string,
             maybe_a_str: Some(&a_str[..]),
             maybe_a_string: Some(a_str.clone()),
-            magic_number: 100,
-            low_quality_pi: 3.14,
-            high_quality_pi: 3.1415,
-            maybe_pi: Some(3.14),
-            maybe_best_pi: Some(3.1415),
+            i16: -45,
+            i32: 456,
+            u64: 4563424,
+            maybe_u8: None,
+            maybe_i16: Some(3),
+            maybe_u32: None,
+            maybe_usize: Some(4456),
+            isize: -365,
+            float: 3.5,
+            double: std::f64::NAN,
+            maybe_float: None,
+            maybe_double: Some(std::f64::MAX),
             borrowed_maybe_a_string: &maybe_a_string,
             borrowed_maybe_a_str: &maybe_a_str,
         }];
 
+        let schema = Arc::new(parse_message_type(schema_str).unwrap());
+        let generated_schema = drs.as_slice().schema().unwrap();
+
+        assert_eq!(&schema, &generated_schema);
+
+        let props = Arc::new(WriterProperties::builder().build());
+        let mut writer =
+            SerializedFileWriter::new(file, generated_schema, props).unwrap();
+
         let mut row_group = writer.next_row_group().unwrap();
         drs.as_slice().write_to_row_group(&mut row_group).unwrap();
         writer.close_row_group(row_group).unwrap();