Posted to commits@arrow.apache.org by tu...@apache.org on 2023/06/20 22:52:32 UTC

[arrow-rs] branch master updated: feat: add strict mode to json reader (#4421)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new c41dc7f20 feat: add strict mode to json reader (#4421)
c41dc7f20 is described below

commit c41dc7f204087045343449cea4382f7e936e8ee0
Author: Sébastien Brochet <bl...@users.noreply.github.com>
AuthorDate: Wed Jun 21 00:52:26 2023 +0200

    feat: add strict mode to json reader (#4421)
    
    When strict mode is enabled, the parser will return an error if it
    encounters a column not present in the schema
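    
    As an illustration of the new option (the schema and input below are
    invented for this note, not taken from the commit), a caller opts in
    through the builder and reads as usual:
    
        use std::io::Cursor;
        use std::sync::Arc;
        
        use arrow_json::ReaderBuilder;
        use arrow_schema::{DataType, Field, Schema};
        
        fn main() -> Result<(), arrow_schema::ArrowError> {
            // The input has a column "c" that the schema does not declare
            let json = r#"{"a": 1, "b": "x", "c": true}"#;
            let schema = Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int64, true),
                Field::new("b", DataType::Utf8, true),
            ]));
        
            // Default behaviour: "c" is silently ignored
            let mut lenient = ReaderBuilder::new(schema.clone())
                .build(Cursor::new(json.as_bytes()))?;
            assert!(lenient.next().unwrap().is_ok());
        
            // Strict mode: decoding fails with a JsonError naming "c"
            let mut strict = ReaderBuilder::new(schema)
                .with_strict_mode(true)
                .build(Cursor::new(json.as_bytes()))?;
            assert!(strict.next().unwrap().is_err());
            Ok(())
        }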
---
 arrow-json/src/reader/list_array.rs   |   2 +
 arrow-json/src/reader/map_array.rs    |   3 +
 arrow-json/src/reader/mod.rs          | 148 +++++++++++++++++++++++++++++-----
 arrow-json/src/reader/struct_array.rs |  24 ++++--
 4 files changed, 152 insertions(+), 25 deletions(-)

diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs
index ad27eb516..d6f7670f2 100644
--- a/arrow-json/src/reader/list_array.rs
+++ b/arrow-json/src/reader/list_array.rs
@@ -35,6 +35,7 @@ impl<O: OffsetSizeTrait> ListArrayDecoder<O> {
     pub fn new(
         data_type: DataType,
         coerce_primitive: bool,
+        strict_mode: bool,
         is_nullable: bool,
     ) -> Result<Self, ArrowError> {
         let field = match &data_type {
@@ -45,6 +46,7 @@ impl<O: OffsetSizeTrait> ListArrayDecoder<O> {
         let decoder = make_decoder(
             field.data_type().clone(),
             coerce_primitive,
+            strict_mode,
             field.is_nullable(),
         )?;
 
diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs
index 2d6fde34d..a1f7e5ace 100644
--- a/arrow-json/src/reader/map_array.rs
+++ b/arrow-json/src/reader/map_array.rs
@@ -34,6 +34,7 @@ impl MapArrayDecoder {
     pub fn new(
         data_type: DataType,
         coerce_primitive: bool,
+        strict_mode: bool,
         is_nullable: bool,
     ) -> Result<Self, ArrowError> {
         let fields = match &data_type {
@@ -56,11 +57,13 @@ impl MapArrayDecoder {
         let keys = make_decoder(
             fields[0].data_type().clone(),
             coerce_primitive,
+            strict_mode,
             fields[0].is_nullable(),
         )?;
         let values = make_decoder(
             fields[1].data_type().clone(),
             coerce_primitive,
+            strict_mode,
             fields[1].is_nullable(),
         )?;
 
diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs
index dd58e1e1a..4e98e2fd8 100644
--- a/arrow-json/src/reader/mod.rs
+++ b/arrow-json/src/reader/mod.rs
@@ -170,6 +170,7 @@ mod timestamp_array;
 pub struct ReaderBuilder {
     batch_size: usize,
     coerce_primitive: bool,
+    strict_mode: bool,
 
     schema: SchemaRef,
 }
@@ -179,13 +180,15 @@ impl ReaderBuilder {
     ///
     /// This could be obtained using [`infer_json_schema`] if not known
     ///
-    /// Any columns not present in `schema` will be ignored
+    /// Any columns not present in `schema` will be ignored, unless `strict_mode` is set to true.
+    /// In this case, an error is returned when a column is missing from `schema`.
     ///
     /// [`infer_json_schema`]: crate::reader::infer_json_schema
     pub fn new(schema: SchemaRef) -> Self {
         Self {
             batch_size: 1024,
             coerce_primitive: false,
+            strict_mode: false,
             schema,
         }
     }
@@ -211,6 +214,15 @@ impl ReaderBuilder {
         }
     }
 
+    /// Sets if the decoder should return an error if it encounters a column not present
+    /// in `schema`
+    pub fn with_strict_mode(self, strict_mode: bool) -> Self {
+        Self {
+            strict_mode,
+            ..self
+        }
+    }
+
     /// Create a [`Reader`] with the provided [`BufRead`]
     pub fn build<R: BufRead>(self, reader: R) -> Result<Reader<R>, ArrowError> {
         Ok(Reader {
@@ -224,6 +236,7 @@ impl ReaderBuilder {
         let decoder = make_decoder(
             DataType::Struct(self.schema.fields.clone()),
             self.coerce_primitive,
+            self.strict_mode,
             false,
         )?;
         let num_fields = self.schema.all_fields().len();
@@ -586,6 +599,7 @@ macro_rules! primitive_decoder {
 fn make_decoder(
     data_type: DataType,
     coerce_primitive: bool,
+    strict_mode: bool,
     is_nullable: bool,
 ) -> Result<Box<dyn ArrayDecoder>, ArrowError> {
     downcast_integer! {
@@ -633,13 +647,13 @@ fn make_decoder(
         DataType::Boolean => Ok(Box::<BooleanArrayDecoder>::default()),
         DataType::Utf8 => Ok(Box::new(StringArrayDecoder::<i32>::new(coerce_primitive))),
         DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::<i64>::new(coerce_primitive))),
-        DataType::List(_) => Ok(Box::new(ListArrayDecoder::<i32>::new(data_type, coerce_primitive, is_nullable)?)),
-        DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::<i64>::new(data_type, coerce_primitive, is_nullable)?)),
-        DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, is_nullable)?)),
+        DataType::List(_) => Ok(Box::new(ListArrayDecoder::<i32>::new(data_type, coerce_primitive, strict_mode, is_nullable)?)),
+        DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::<i64>::new(data_type, coerce_primitive, strict_mode, is_nullable)?)),
+        DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable)?)),
         DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => {
             Err(ArrowError::JsonError(format!("{data_type} is not supported by JSON")))
         }
-        DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, is_nullable)?)),
+        DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable)?)),
         d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in JSON reader")))
     }
 }
@@ -670,6 +684,7 @@ mod tests {
         buf: &str,
         batch_size: usize,
         coerce_primitive: bool,
+        strict_mode: bool,
         schema: SchemaRef,
     ) -> Vec<RecordBatch> {
         let mut unbuffered = vec![];
@@ -693,6 +708,7 @@ mod tests {
                 let buffered = ReaderBuilder::new(schema.clone())
                     .with_batch_size(batch_size)
                     .with_coerce_primitive(coerce_primitive)
+                    .with_strict_mode(strict_mode)
                     .build(BufReader::with_capacity(b, Cursor::new(buf.as_bytes())))
                     .unwrap()
                     .collect::<Result<Vec<_>, _>>()
@@ -724,7 +740,7 @@ mod tests {
             Field::new("e", DataType::Date64, true),
         ]));
 
-        let batches = do_read(buf, 1024, false, schema);
+        let batches = do_read(buf, 1024, false, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = batches[0].column(0).as_primitive::<Int64Type>();
@@ -763,7 +779,7 @@ mod tests {
         {"a": "1", "b": "2"}
         {"a": "hello", "b": "shoo"}
         {"b": "\t😁foo", "a": "\nfoobar\ud83d\ude00\u0061\u0073\u0066\u0067\u00FF"}
-        
+
         {"b": null}
         {"b": "", "a": null}
 
@@ -773,7 +789,7 @@ mod tests {
             Field::new("b", DataType::LargeUtf8, true),
         ]));
 
-        let batches = do_read(buf, 1024, false, schema);
+        let batches = do_read(buf, 1024, false, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = batches[0].column(0).as_string::<i32>();
@@ -826,7 +842,7 @@ mod tests {
             ),
         ]));
 
-        let batches = do_read(buf, 1024, false, schema);
+        let batches = do_read(buf, 1024, false, false, schema);
         assert_eq!(batches.len(), 1);
 
         let list = batches[0].column(0).as_list::<i32>();
@@ -895,7 +911,7 @@ mod tests {
             ),
         ]));
 
-        let batches = do_read(buf, 1024, false, schema);
+        let batches = do_read(buf, 1024, false, false, schema);
         assert_eq!(batches.len(), 1);
 
         let nested = batches[0].column(0).as_struct();
@@ -941,7 +957,7 @@ mod tests {
 
         let schema = Arc::new(Schema::new(vec![map]));
 
-        let batches = do_read(buf, 1024, false, schema);
+        let batches = do_read(buf, 1024, false, false, schema);
         assert_eq!(batches.len(), 1);
 
         let map = batches[0].column(0).as_map();
@@ -1015,7 +1031,7 @@ mod tests {
             Field::new("c", DataType::Utf8, true),
         ]));
 
-        let batches = do_read(buf, 1024, true, schema);
+        let batches = do_read(buf, 1024, true, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = batches[0].column(0).as_string::<i32>();
@@ -1063,7 +1079,7 @@ mod tests {
             Field::new("c", data_type, true),
         ]));
 
-        let batches = do_read(buf, 1024, true, schema);
+        let batches = do_read(buf, 1024, true, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = batches[0].column(0).as_primitive::<T>();
@@ -1121,7 +1137,7 @@ mod tests {
             Field::new("d", with_timezone, true),
         ]));
 
-        let batches = do_read(buf, 1024, true, schema);
+        let batches = do_read(buf, 1024, true, false, schema);
         assert_eq!(batches.len(), 1);
 
         let unit_in_nanos: i64 = match T::UNIT {
@@ -1221,7 +1237,7 @@ mod tests {
             Field::new("c", T::DATA_TYPE, true),
         ]));
 
-        let batches = do_read(buf, 1024, true, schema);
+        let batches = do_read(buf, 1024, true, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = batches[0].column(0).as_primitive::<T>();
@@ -1298,7 +1314,7 @@ mod tests {
             ),
         ]));
 
-        let batches = do_read(json, 1024, true, schema);
+        let batches = do_read(json, 1024, true, false, schema);
         assert_eq!(batches.len(), 1);
 
         let s: StructArray = batches.into_iter().next().unwrap().into();
@@ -1373,7 +1389,7 @@ mod tests {
             Field::new("u64", DataType::UInt64, true),
         ]));
 
-        let batches = do_read(buf, 1024, true, schema);
+        let batches = do_read(buf, 1024, true, false, schema);
         assert_eq!(batches.len(), 1);
 
         let i64 = batches[0].column(0).as_primitive::<Int64Type>();
@@ -1397,7 +1413,7 @@ mod tests {
             true,
         )]));
 
-        let batches = do_read(buf, 1024, true, schema);
+        let batches = do_read(buf, 1024, true, false, schema);
         assert_eq!(batches.len(), 1);
 
         let i64 = batches[0]
@@ -1406,6 +1422,98 @@ mod tests {
         assert_eq!(i64.values(), &[i64::MAX, i64::MIN, 900000]);
     }
 
+    #[test]
+    fn test_strict_mode_no_missing_columns_in_schema() {
+        let buf = r#"
+        {"a": 1, "b": "2", "c": true}
+        {"a": 2E0, "b": "4", "c": false}
+        "#;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int16, false),
+            Field::new("b", DataType::Utf8, false),
+            Field::new("c", DataType::Boolean, false),
+        ]));
+
+        let batches = do_read(buf, 1024, true, true, schema);
+        assert_eq!(batches.len(), 1);
+
+        let buf = r#"
+        {"a": 1, "b": "2", "c": {"a": true, "b": 1}}
+        {"a": 2E0, "b": "4", "c": {"a": false, "b": 2}}
+        "#;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int16, false),
+            Field::new("b", DataType::Utf8, false),
+            Field::new_struct(
+                "c",
+                vec![
+                    Field::new("a", DataType::Boolean, false),
+                    Field::new("b", DataType::Int16, false),
+                ],
+                false,
+            ),
+        ]));
+
+        let batches = do_read(buf, 1024, true, true, schema);
+        assert_eq!(batches.len(), 1);
+    }
+
+    #[test]
+    fn test_strict_mode_missing_columns_in_schema() {
+        let buf = r#"
+        {"a": 1, "b": "2", "c": true}
+        {"a": 2E0, "b": "4", "c": false}
+        "#;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int16, true),
+            Field::new("c", DataType::Boolean, true),
+        ]));
+
+        let err = ReaderBuilder::new(schema)
+            .with_batch_size(1024)
+            .with_strict_mode(true)
+            .build(Cursor::new(buf.as_bytes()))
+            .unwrap()
+            .read()
+            .unwrap_err();
+
+        assert_eq!(
+            err.to_string(),
+            "Json error: column 'b' missing from schema"
+        );
+
+        let buf = r#"
+        {"a": 1, "b": "2", "c": {"a": true, "b": 1}}
+        {"a": 2E0, "b": "4", "c": {"a": false, "b": 2}}
+        "#;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int16, false),
+            Field::new("b", DataType::Utf8, false),
+            Field::new_struct(
+                "c",
+                vec![Field::new("a", DataType::Boolean, false)],
+                false,
+            ),
+        ]));
+
+        let err = ReaderBuilder::new(schema)
+            .with_batch_size(1024)
+            .with_strict_mode(true)
+            .build(Cursor::new(buf.as_bytes()))
+            .unwrap()
+            .read()
+            .unwrap_err();
+
+        assert_eq!(
+            err.to_string(),
+            "Json error: whilst decoding field 'c': column 'b' missing from schema"
+        );
+    }
+
     fn read_file(path: &str, schema: Option<Schema>) -> Reader<BufReader<File>> {
         let file = File::open(path).unwrap();
         let mut reader = BufReader::new(file);
@@ -1628,7 +1736,7 @@ mod tests {
             true,
         )]));
 
-        let batches = do_read(json_content, 1024, false, schema);
+        let batches = do_read(json_content, 1024, false, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = batches[0].column(0).as_list::<i32>();
@@ -1656,7 +1764,7 @@ mod tests {
             true,
         )]));
 
-        let batches = do_read(json_content, 1024, false, schema);
+        let batches = do_read(json_content, 1024, false, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = batches[0].column(0).as_list::<i32>();
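
The new tests above pin down the exact error strings produced in strict
mode. A caller wanting to branch on this failure can match on the
JsonError variant the decoder uses; the helper below is a hypothetical
sketch, assuming only the message shape asserted in those tests:

    use arrow_schema::ArrowError;

    /// Hypothetical helper: true if a read failed because strict mode met
    /// a column the schema does not declare (message shape per the tests)
    fn is_unknown_column(err: &ArrowError) -> bool {
        matches!(err, ArrowError::JsonError(msg) if msg.contains("missing from schema"))
    }
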
diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs
index 3d24a927d..77d7e170d 100644
--- a/arrow-json/src/reader/struct_array.rs
+++ b/arrow-json/src/reader/struct_array.rs
@@ -25,6 +25,7 @@ use arrow_schema::{ArrowError, DataType, Fields};
 pub struct StructArrayDecoder {
     data_type: DataType,
     decoders: Vec<Box<dyn ArrayDecoder>>,
+    strict_mode: bool,
     is_nullable: bool,
 }
 
@@ -32,6 +33,7 @@ impl StructArrayDecoder {
     pub fn new(
         data_type: DataType,
         coerce_primitive: bool,
+        strict_mode: bool,
         is_nullable: bool,
     ) -> Result<Self, ArrowError> {
         let decoders = struct_fields(&data_type)
@@ -41,13 +43,19 @@ impl StructArrayDecoder {
                 // StructArrayDecoder::decode verifies that if the child is not nullable
                 // it doesn't contain any nulls not masked by its parent
                 let nullable = f.is_nullable() || is_nullable;
-                make_decoder(f.data_type().clone(), coerce_primitive, nullable)
+                make_decoder(
+                    f.data_type().clone(),
+                    coerce_primitive,
+                    strict_mode,
+                    nullable,
+                )
             })
             .collect::<Result<Vec<_>, ArrowError>>()?;
 
         Ok(Self {
             data_type,
             decoders,
+            strict_mode,
             is_nullable,
         })
     }
@@ -86,10 +94,16 @@ impl ArrayDecoder for StructArrayDecoder {
                 };
 
                 // Update child pos if match found
-                if let Some(field_idx) =
-                    fields.iter().position(|x| x.name() == field_name)
-                {
-                    child_pos[field_idx][row] = cur_idx + 1;
+                match fields.iter().position(|x| x.name() == field_name) {
+                    Some(field_idx) => child_pos[field_idx][row] = cur_idx + 1,
+                    None => {
+                        if self.strict_mode {
+                            return Err(ArrowError::JsonError(format!(
+                                "column '{}' missing from schema",
+                                field_name
+                            )));
+                        }
+                    }
                 }
 
                 // Advance to next field
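
A closing note on the hunk above: the error constructed at this site
carries only the offending column name. The "whilst decoding field 'c':
..." prefix asserted in the nested-struct test is prepended by the
enclosing decoder as the error propagates upward; the sketch below shows
that wrapping shape (names hypothetical, inferred from the test strings
rather than shown in this diff):

    use arrow_schema::ArrowError;

    // Hypothetical: how an enclosing decoder can attach field context to
    // a child error, matching the message shape asserted in the tests
    fn with_field_context(field: &str, err: ArrowError) -> ArrowError {
        match err {
            ArrowError::JsonError(msg) => ArrowError::JsonError(format!(
                "whilst decoding field '{field}': {msg}"
            )),
            other => other,
        }
    }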