You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@avro.apache.org by Simon Gittins <sg...@gmail.com> on 2023/11/04 01:25:19 UTC

[rust] Better support for union schema (patch included)

Hello

I've written some failing Rust unit tests showing:
- Non-trivial union schemas (without a null variant) failing to round-trip
through encode/decode (serde)
- Non-trivial union schemas with the null variant at position 0 failing to
round-trip (serde)
- Non-trivial union schemas with the null variant at a position other than 0
failing to round-trip (serde)

I've also attached (and appended below) an update to the encoder that makes
the above tests pass. Is this a patch the team would be interested in?

Thanks
Simon

diff --git a/lang/rust/avro/src/encode.rs b/lang/rust/avro/src/encode.rs
index 4593779ac..829a8ee6c 100644
--- a/lang/rust/avro/src/encode.rs
+++ b/lang/rust/avro/src/encode.rs
@@ -19,7 +19,7 @@ use crate::{
     decimal::serialize_big_decimal,
     schema::{
         DecimalSchema, EnumSchema, FixedSchema, Name, Namespace,
RecordSchema, ResolvedSchema,
-        Schema, SchemaKind,
+        UnionSchema, Schema, SchemaKind,
     },
     types::{Value, ValueKind},
     util::{zig_i32, zig_i64},
@@ -71,7 +71,20 @@ pub(crate) fn encode_internal<S: Borrow<Schema>>(
     }

     match value {
-        Value::Null => (),
+        Value::Null => {
+            match schema {
+                Schema::Union(s) => {
+                    match s.schemas.iter().position(|sch|*sch ==
Schema::Null) {
+                        None =>
+                            return Err(Error::EncodeValueAsSchemaError {
+                                value_kind: ValueKind::Null,
+                                supported_schema: vec![SchemaKind::Null,
SchemaKind::Union], }),
+                        Some(p) => encode_long(p as i64, buffer),
+                    }
+                }
+                _ => ()
+            }
+        },
         Value::Boolean(b) => buffer.push(u8::from(*b)),
         // Pattern | Pattern here to signify that these _must_ have the
same encoding.
         Value::Int(i) | Value::Date(i) | Value::TimeMillis(i) =>
encode_int(*i, buffer),
@@ -242,6 +255,21 @@ pub(crate) fn encode_internal<S: Borrow<Schema>>(
                         ));
                     }
                 }
+            } else if let Schema::Union(UnionSchema{ schemas, .. }) =
schema {
+                let original_size = buffer.len();
+                for (index,s) in schemas.iter().enumerate() {
+                    encode_long(index as i64, buffer);
+                    match encode_internal(value, s.borrow(), names,
enclosing_namespace, buffer) {
+                        Ok(_) => return Ok(()),
+                        Err(e) => {
+                            buffer.truncate(original_size); //undo any
partial encoding
+                        }
+                    }
+                }
+                return Err(Error::EncodeValueAsSchemaError {
+                    value_kind: ValueKind::Record,
+                    supported_schema: vec![SchemaKind::Record,
SchemaKind::Union],
+                });
             } else {
                 error!("invalid schema type for Record: {:?}", schema);
                 return Err(Error::EncodeValueAsSchemaError {
diff --git a/lang/rust/avro/tests/union_schema.rs b/lang/rust/avro/tests/
union_schema.rs
index e69de29bb..1dc19d25c 100644
--- a/lang/rust/avro/tests/union_schema.rs
+++ b/lang/rust/avro/tests/union_schema.rs
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use serde::{Deserialize, Serialize};
+use serde::de::DeserializeOwned;
+use apache_avro::{from_avro_datum, to_avro_datum, to_value, from_value,
types, Schema, Writer, Reader, Codec};
+
+
+static SCHEMA_A_STR: &str = r#"{
+        "name": "A",
+        "type": "record",
+        "fields": [
+            {"name": "field_a", "type": "float"}
+        ]
+    }"#;
+
+static SCHEMA_B_STR: &str = r#"{
+        "name": "B",
+        "type": "record",
+        "fields": [
+            {"name": "field_b", "type": "long"}
+        ]
+    }"#;
+
+static SCHEMA_C_STR: &str = r#"{
+        "name": "C",
+        "type": "record",
+        "fields": [
+            {"name": "field_union", "type": ["A", "B"]},
+            {"name": "field_c", "type": "string"}
+        ]
+    }"#;
+
+#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
+struct A {
+    field_a: f32,
+}
+
+#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
+struct B {
+    field_b: i64,
+}
+
+#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
+#[serde(untagged)]
+enum UnionAB {
+    A(A),
+    B(B),
+}
+
+#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
+struct C {
+    field_union: UnionAB,
+    field_c: String
+}
+
+fn encode_decode<T> (input: &T,schema: &Schema,schemata: &Vec<Schema>) -> T
+    where T: DeserializeOwned + Serialize {
+    let mut encoded: Vec<u8> = Vec::new();
+    let mut writer = Writer::with_schemata(&schema,
schemata.iter().collect(), &mut encoded, Codec::Null);
+    writer.append_ser((input.clone())).unwrap();
+    writer.flush().unwrap();
+
+    let mut reader = Reader::with_schemata(schema,
schemata.iter().collect(), encoded.as_slice()).unwrap();
+    from_value::<T>(&reader.next().unwrap().unwrap()).unwrap()
+}
+
+
+#[test]
+fn union_schema_round_trip_no_null()  {
+    let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR,
SCHEMA_B_STR, SCHEMA_C_STR]).expect("parsing schemata");
+
+    {
+        let input = C { field_union: (UnionAB::A(A { field_a: 45.5 })),
field_c: "foo".to_string() };
+        let output = encode_decode(&input,&schemata[2],&schemata);
+        assert_eq!(input,output);
+    }
+    {
+        let input = C { field_union: (UnionAB::B(B { field_b: 73 })),
field_c: "bar".to_string() };
+        let output = encode_decode(&input,&schemata[2],&schemata);
+        assert_eq!(input,output);
+    }
+}
+
+static SCHEMA_D_STR: &str = r#"{
+        "name": "D",
+        "type": "record",
+        "fields": [
+            {"name": "field_union", "type": ["null", "A", "B"]},
+            {"name": "field_d", "type": "string"}
+        ]
+    }"#;
+
+#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
+#[serde(untagged)]
+enum UnionNoneAB {
+    None,
+    A(A),
+    B(B),
+}
+
+#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
+struct D {
+    field_union: UnionNoneAB,
+    field_d: String
+}
+
+#[test]
+fn union_schema_round_trip_null_at_start()  {
+    let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR,
SCHEMA_B_STR, SCHEMA_D_STR]).expect("parsing schemata");
+
+    {
+        let input = D { field_union: UnionNoneAB::A(A { field_a: 54.25 }),
field_d: "fooy".to_string() };
+        let output = encode_decode(&input,&schemata[2],&schemata);
+        assert_eq!(input,output);
+    }
+    {
+        let input = D { field_union: UnionNoneAB::None, field_d:
"fooyy".to_string() };
+        let output = encode_decode(&input,&schemata[2],&schemata);
+        assert_eq!(input,output);
+    }
+    {
+        let input = D { field_union: UnionNoneAB::B(B { field_b: 103 }),
field_d: "foov".to_string() };
+        let output = encode_decode(&input,&schemata[2],&schemata);
+        assert_eq!(input,output);
+    }
+}
+
+static SCHEMA_E_STR: &str = r#"{
+        "name": "E",
+        "type": "record",
+        "fields": [
+            {"name": "field_union", "type": ["A", "null", "B"]},
+            {"name": "field_e", "type": "string"}
+        ]
+    }"#;
+
+#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
+#[serde(untagged)]
+enum UnionANoneB {
+    A(A),
+    None,
+    B(B),
+}
+
+#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
+struct E {
+    field_union: UnionANoneB,
+    field_e: String
+}
+
+#[test]
+fn union_schema_round_trip_with_out_of_order_null()  {
+    let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR,
SCHEMA_B_STR, SCHEMA_E_STR]).expect("parsing schemata");
+
+    {
+        let input = E { field_union: UnionANoneB::A(A { field_a: 23.75 }),
field_e: "barme".to_string() };
+        let output = encode_decode(&input,&schemata[2],&schemata);
+        assert_eq!(input,output);
+    }
+    {
+        let input = E { field_union: UnionANoneB::None, field_e:
"barme2".to_string() };
+        let output = encode_decode(&input,&schemata[2],&schemata);
+        assert_eq!(input,output);
+    }
+    {
+        let input = E { field_union: UnionANoneB::B(B { field_b: 89 }),
field_e: "barme3".to_string() };
+        let output = encode_decode(&input,&schemata[2],&schemata);
+        assert_eq!(input,output);
+    }
+}

Re: [rust] Better support for union schema (patch included)

Posted by Martin Grigorov <mg...@apache.org>.
Hi Simon,

Yes, we are interested in any and all kinds of improvements!
Please create a JIRA issue and send a Pull Request with the diff above!
Thank you!

Martin

On Sat, Nov 4, 2023 at 3:25 AM Simon Gittins <sg...@gmail.com> wrote:

> Hello
>
> I've written some failing Rust unit tests showing:
> - Non-trivial union schemas (without a null variant) failing to
> round-trip through encode/decode (serde)
> - Non-trivial union schemas with the null variant at position 0 failing
> to round-trip (serde)
> - Non-trivial union schemas with the null variant at a position other than 0
> failing to round-trip (serde)
>
> I've also attached (and appended below) an update to the encoder that makes
> the above tests pass. Is this a patch the team would be interested in?
>
> Thanks
> Simon
>
> diff --git a/lang/rust/avro/src/encode.rs b/lang/rust/avro/src/encode.rs
> index 4593779ac..829a8ee6c 100644
> --- a/lang/rust/avro/src/encode.rs
> +++ b/lang/rust/avro/src/encode.rs
> @@ -19,7 +19,7 @@ use crate::{
>      decimal::serialize_big_decimal,
>      schema::{
>          DecimalSchema, EnumSchema, FixedSchema, Name, Namespace,
> RecordSchema, ResolvedSchema,
> -        Schema, SchemaKind,
> +        UnionSchema, Schema, SchemaKind,
>      },
>      types::{Value, ValueKind},
>      util::{zig_i32, zig_i64},
> @@ -71,7 +71,20 @@ pub(crate) fn encode_internal<S: Borrow<Schema>>(
>      }
>
>      match value {
> -        Value::Null => (),
> +        Value::Null => {
> +            match schema {
> +                Schema::Union(s) => {
> +                    match s.schemas.iter().position(|sch|*sch ==
> Schema::Null) {
> +                        None =>
> +                            return Err(Error::EncodeValueAsSchemaError {
> +                                value_kind: ValueKind::Null,
> +                                supported_schema: vec![SchemaKind::Null,
> SchemaKind::Union], }),
> +                        Some(p) => encode_long(p as i64, buffer),
> +                    }
> +                }
> +                _ => ()
> +            }
> +        },
>          Value::Boolean(b) => buffer.push(u8::from(*b)),
>          // Pattern | Pattern here to signify that these _must_ have the
> same encoding.
>          Value::Int(i) | Value::Date(i) | Value::TimeMillis(i) =>
> encode_int(*i, buffer),
> @@ -242,6 +255,21 @@ pub(crate) fn encode_internal<S: Borrow<Schema>>(
>                          ));
>                      }
>                  }
> +            } else if let Schema::Union(UnionSchema{ schemas, .. }) =
> schema {
> +                let original_size = buffer.len();
> +                for (index,s) in schemas.iter().enumerate() {
> +                    encode_long(index as i64, buffer);
> +                    match encode_internal(value, s.borrow(), names,
> enclosing_namespace, buffer) {
> +                        Ok(_) => return Ok(()),
> +                        Err(e) => {
> +                            buffer.truncate(original_size); //undo any
> partial encoding
> +                        }
> +                    }
> +                }
> +                return Err(Error::EncodeValueAsSchemaError {
> +                    value_kind: ValueKind::Record,
> +                    supported_schema: vec![SchemaKind::Record,
> SchemaKind::Union],
> +                });
>              } else {
>                  error!("invalid schema type for Record: {:?}", schema);
>                  return Err(Error::EncodeValueAsSchemaError {
> diff --git a/lang/rust/avro/tests/union_schema.rs b/lang/rust/avro/tests/
> union_schema.rs
> index e69de29bb..1dc19d25c 100644
> --- a/lang/rust/avro/tests/union_schema.rs
> +++ b/lang/rust/avro/tests/union_schema.rs
> @@ -0,0 +1,185 @@
> +// Licensed to the Apache Software Foundation (ASF) under one
> +// or more contributor license agreements.  See the NOTICE file
> +// distributed with this work for additional information
> +// regarding copyright ownership.  The ASF licenses this file
> +// to you under the Apache License, Version 2.0 (the
> +// "License"); you may not use this file except in compliance
> +// with the License.  You may obtain a copy of the License at
> +//
> +//   http://www.apache.org/licenses/LICENSE-2.0
> +//
> +// Unless required by applicable law or agreed to in writing,
> +// software distributed under the License is distributed on an
> +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
> +// KIND, either express or implied.  See the License for the
> +// specific language governing permissions and limitations
> +// under the License.
> +
> +use serde::{Deserialize, Serialize};
> +use serde::de::DeserializeOwned;
> +use apache_avro::{from_avro_datum, to_avro_datum, to_value, from_value,
> types, Schema, Writer, Reader, Codec};
> +
> +
> +static SCHEMA_A_STR: &str = r#"{
> +        "name": "A",
> +        "type": "record",
> +        "fields": [
> +            {"name": "field_a", "type": "float"}
> +        ]
> +    }"#;
> +
> +static SCHEMA_B_STR: &str = r#"{
> +        "name": "B",
> +        "type": "record",
> +        "fields": [
> +            {"name": "field_b", "type": "long"}
> +        ]
> +    }"#;
> +
> +static SCHEMA_C_STR: &str = r#"{
> +        "name": "C",
> +        "type": "record",
> +        "fields": [
> +            {"name": "field_union", "type": ["A", "B"]},
> +            {"name": "field_c", "type": "string"}
> +        ]
> +    }"#;
> +
> +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
> +struct A {
> +    field_a: f32,
> +}
> +
> +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
> +struct B {
> +    field_b: i64,
> +}
> +
> +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
> +#[serde(untagged)]
> +enum UnionAB {
> +    A(A),
> +    B(B),
> +}
> +
> +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
> +struct C {
> +    field_union: UnionAB,
> +    field_c: String
> +}
> +
> +fn encode_decode<T> (input: &T,schema: &Schema,schemata: &Vec<Schema>) ->
> T
> +    where T: DeserializeOwned + Serialize {
> +    let mut encoded: Vec<u8> = Vec::new();
> +    let mut writer = Writer::with_schemata(&schema,
> schemata.iter().collect(), &mut encoded, Codec::Null);
> +    writer.append_ser((input.clone())).unwrap();
> +    writer.flush().unwrap();
> +
> +    let mut reader = Reader::with_schemata(schema,
> schemata.iter().collect(), encoded.as_slice()).unwrap();
> +    from_value::<T>(&reader.next().unwrap().unwrap()).unwrap()
> +}
> +
> +
> +#[test]
> +fn union_schema_round_trip_no_null()  {
> +    let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR,
> SCHEMA_B_STR, SCHEMA_C_STR]).expect("parsing schemata");
> +
> +    {
> +        let input = C { field_union: (UnionAB::A(A { field_a: 45.5 })),
> field_c: "foo".to_string() };
> +        let output = encode_decode(&input,&schemata[2],&schemata);
> +        assert_eq!(input,output);
> +    }
> +    {
> +        let input = C { field_union: (UnionAB::B(B { field_b: 73 })),
> field_c: "bar".to_string() };
> +        let output = encode_decode(&input,&schemata[2],&schemata);
> +        assert_eq!(input,output);
> +    }
> +}
> +
> +static SCHEMA_D_STR: &str = r#"{
> +        "name": "D",
> +        "type": "record",
> +        "fields": [
> +            {"name": "field_union", "type": ["null", "A", "B"]},
> +            {"name": "field_d", "type": "string"}
> +        ]
> +    }"#;
> +
> +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
> +#[serde(untagged)]
> +enum UnionNoneAB {
> +    None,
> +    A(A),
> +    B(B),
> +}
> +
> +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
> +struct D {
> +    field_union: UnionNoneAB,
> +    field_d: String
> +}
> +
> +#[test]
> +fn union_schema_round_trip_null_at_start()  {
> +    let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR,
> SCHEMA_B_STR, SCHEMA_D_STR]).expect("parsing schemata");
> +
> +    {
> +        let input = D { field_union: UnionNoneAB::A(A { field_a: 54.25
> }), field_d: "fooy".to_string() };
> +        let output = encode_decode(&input,&schemata[2],&schemata);
> +        assert_eq!(input,output);
> +    }
> +    {
> +        let input = D { field_union: UnionNoneAB::None, field_d:
> "fooyy".to_string() };
> +        let output = encode_decode(&input,&schemata[2],&schemata);
> +        assert_eq!(input,output);
> +    }
> +    {
> +        let input = D { field_union: UnionNoneAB::B(B { field_b: 103 }),
> field_d: "foov".to_string() };
> +        let output = encode_decode(&input,&schemata[2],&schemata);
> +        assert_eq!(input,output);
> +    }
> +}
> +
> +static SCHEMA_E_STR: &str = r#"{
> +        "name": "E",
> +        "type": "record",
> +        "fields": [
> +            {"name": "field_union", "type": ["A", "null", "B"]},
> +            {"name": "field_e", "type": "string"}
> +        ]
> +    }"#;
> +
> +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
> +#[serde(untagged)]
> +enum UnionANoneB {
> +    A(A),
> +    None,
> +    B(B),
> +}
> +
> +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)]
> +struct E {
> +    field_union: UnionANoneB,
> +    field_e: String
> +}
> +
> +#[test]
> +fn union_schema_round_trip_with_out_of_order_null()  {
> +    let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR,
> SCHEMA_B_STR, SCHEMA_E_STR]).expect("parsing schemata");
> +
> +    {
> +        let input = E { field_union: UnionANoneB::A(A { field_a: 23.75
> }), field_e: "barme".to_string() };
> +        let output = encode_decode(&input,&schemata[2],&schemata);
> +        assert_eq!(input,output);
> +    }
> +    {
> +        let input = E { field_union: UnionANoneB::None, field_e:
> "barme2".to_string() };
> +        let output = encode_decode(&input,&schemata[2],&schemata);
> +        assert_eq!(input,output);
> +    }
> +    {
> +        let input = E { field_union: UnionANoneB::B(B { field_b: 89 }),
> field_e: "barme3".to_string() };
> +        let output = encode_decode(&input,&schemata[2],&schemata);
> +        assert_eq!(input,output);
> +    }
> +}
>
>
>
>