You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@avro.apache.org by ko...@apache.org on 2021/05/21 13:11:22 UTC
[avro] branch master updated: AVRO-3104 Allow Comparing Schemas to non-Schemas (#1221)
This is an automated email from the ASF dual-hosted git repository.
kojiromike pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git
The following commit(s) were added to refs/heads/master by this push:
new 72a0fef AVRO-3104 Allow Comparing Schemas to non-Schemas (#1221)
72a0fef is described below
commit 72a0fef3e9247bd94621288fb2e6d88408e3c836
Author: Michael A. Smith <mi...@smith-li.com>
AuthorDate: Fri May 21 09:11:07 2021 -0400
AVRO-3104 Allow Comparing Schemas to non-Schemas (#1221)
Avro no longer raises an `AttributeError` when you test whether a schema is equal to an object that is not a schema, as in `some_schema == not_a_schema`; such comparisons now return `False` instead of raising.
---
lang/py/avro/schema.py | 102 +++++++++++++++------------------------
lang/py/avro/test/test_schema.py | 5 ++
2 files changed, 44 insertions(+), 63 deletions(-)
diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py
index 7999d0f..43531a3 100644
--- a/lang/py/avro/schema.py
+++ b/lang/py/avro/schema.py
@@ -148,7 +148,26 @@ def _is_timezone_aware_datetime(dt):
# Base Classes
#
-class CanonicalPropertiesMixin(object):
+class EqualByJsonMixin:
+ """Equal if the json serializations are equal."""
+ def __eq__(self, that):
+ try:
+ that_str = json.loads(str(that))
+ except json.decoder.JSONDecodeError:
+ return False
+ return json.loads(str(self)) == that_str
+
+
+class EqualByPropsMixin:
+ """Equal if the props are equal."""
+ def __eq__(self, that):
+ try:
+ return self.props == that.props
+ except AttributeError:
+ return False
+
+
+class CanonicalPropertiesMixin:
"""A Mixin that provides canonical properties to Schema and Field types."""
@property
def canonical_properties(self):
@@ -253,6 +272,13 @@ class Schema(abc.ABC, CanonicalPropertiesMixin):
# The separators eliminate whitespace around commas and colons.
return json.dumps(self.to_canonical_json(), separators=(",", ":"))
+ @abc.abstractmethod
+ def __eq__(self, that):
+ """
+ Determines how two schema are compared.
+ Consider the mixins EqualByPropsMixin and EqualByJsonMixin
+ """
+
class Name:
"""Class to describe Avro name."""
@@ -301,7 +327,10 @@ class Name:
def __eq__(self, other):
"""Equality of names is defined on the fullname and is case-sensitive."""
- return isinstance(other, Name) and self.fullname == other.fullname
+ try:
+ return self.fullname == other.fullname
+ except AttributeError:
+ return False
@property
def fullname(self):
@@ -449,7 +478,7 @@ class DecimalLogicalSchema(LogicalSchema):
super(DecimalLogicalSchema, self).__init__('decimal')
-class Field(CanonicalPropertiesMixin):
+class Field(CanonicalPropertiesMixin, EqualByJsonMixin):
def __init__(self, type, name, has_default, default=None,
order=None, names=None, doc=None, other_props=None):
# Ensure valid ctor args
@@ -526,16 +555,12 @@ class Field(CanonicalPropertiesMixin):
return to_dump
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
-
#
# Primitive Types
#
-class PrimitiveSchema(Schema):
+class PrimitiveSchema(EqualByPropsMixin, Schema):
"""Valid primitive types are in PRIMITIVE_TYPES."""
_validators = {
@@ -589,9 +614,6 @@ class PrimitiveSchema(Schema):
validator = self._validators.get(self.type, lambda x: False)
return self if validator(datum) else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# Decimal Bytes Type
#
@@ -615,14 +637,11 @@ class BytesDecimalSchema(PrimitiveSchema, DecimalLogicalSchema):
"""Return self if datum is a Decimal object, else None."""
return self if isinstance(datum, decimal.Decimal) else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# Complex Types (non-recursive)
#
-class FixedSchema(NamedSchema):
+class FixedSchema(EqualByPropsMixin, NamedSchema):
def __init__(self, name, namespace, size, names=None, other_props=None):
# Ensure valid ctor args
if not isinstance(size, int) or size < 0:
@@ -665,9 +684,6 @@ class FixedSchema(NamedSchema):
"""Return self if datum is a valid representation of this schema, else None."""
return self if isinstance(datum, bytes) and len(datum) == self.size else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# Decimal Fixed Type
#
@@ -692,11 +708,8 @@ class FixedDecimalSchema(FixedSchema, DecimalLogicalSchema):
"""Return self if datum is a Decimal object, else None."""
return self if isinstance(datum, decimal.Decimal) else None
- def __eq__(self, that):
- return self.props == that.props
-
-class EnumSchema(NamedSchema):
+class EnumSchema(EqualByPropsMixin, NamedSchema):
def __init__(self, name, namespace, symbols, names=None, doc=None, other_props=None, validate_enum_symbols=True):
"""
@arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names.
@@ -756,15 +769,12 @@ class EnumSchema(NamedSchema):
"""Return self if datum is a valid member of this Enum, else None."""
return self if datum in self.symbols else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# Complex Types (recursive)
#
-class ArraySchema(Schema):
+class ArraySchema(EqualByJsonMixin, Schema):
def __init__(self, items, names=None, other_props=None):
# Call parent ctor
Schema.__init__(self, 'array', other_props)
@@ -814,12 +824,8 @@ class ArraySchema(Schema):
"""Return self if datum is a valid representation of this schema, else None."""
return self if isinstance(datum, list) else None
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
-
-class MapSchema(Schema):
+class MapSchema(EqualByJsonMixin, Schema):
def __init__(self, values, names=None, other_props=None):
# Call parent ctor
Schema.__init__(self, 'map', other_props)
@@ -868,12 +874,8 @@ class MapSchema(Schema):
"""Return self if datum is a valid representation of this schema, else None."""
return self if isinstance(datum, dict) and all(isinstance(key, str) for key in datum) else None
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
-
-class UnionSchema(Schema):
+class UnionSchema(EqualByJsonMixin, Schema):
"""
names is a dictionary of schema objects
"""
@@ -938,10 +940,6 @@ class UnionSchema(Schema):
if branch.validate(datum) is not None:
return branch
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
-
class ErrorUnionSchema(UnionSchema):
def __init__(self, schemas, names=None):
@@ -961,7 +959,7 @@ class ErrorUnionSchema(UnionSchema):
return to_dump
-class RecordSchema(NamedSchema):
+class RecordSchema(EqualByJsonMixin, NamedSchema):
@staticmethod
def make_field_objects(field_data, names):
"""We're going to need to make message parameters too."""
@@ -1082,10 +1080,6 @@ class RecordSchema(NamedSchema):
"""Return self if datum is a valid representation of this schema, else None"""
return self if isinstance(datum, dict) and {f.name for f in self.fields}.issuperset(datum.keys()) else None
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
-
#
# Date Type
@@ -1103,9 +1097,6 @@ class DateSchema(LogicalSchema, PrimitiveSchema):
"""Return self if datum is a valid date object, else None."""
return self if isinstance(datum, datetime.date) else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# time-millis Type
#
@@ -1123,9 +1114,6 @@ class TimeMillisSchema(LogicalSchema, PrimitiveSchema):
"""Return self if datum is a valid representation of this schema, else None."""
return self if isinstance(datum, datetime.time) else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# time-micros Type
#
@@ -1143,9 +1131,6 @@ class TimeMicrosSchema(LogicalSchema, PrimitiveSchema):
"""Return self if datum is a valid representation of this schema, else None."""
return self if isinstance(datum, datetime.time) else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# timestamp-millis Type
#
@@ -1162,9 +1147,6 @@ class TimestampMillisSchema(LogicalSchema, PrimitiveSchema):
def validate(self, datum):
return self if isinstance(datum, datetime.datetime) and _is_timezone_aware_datetime(datum) else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# timestamp-micros Type
#
@@ -1181,9 +1163,6 @@ class TimestampMicrosSchema(LogicalSchema, PrimitiveSchema):
def validate(self, datum):
return self if isinstance(datum, datetime.datetime) and _is_timezone_aware_datetime(datum) else None
- def __eq__(self, that):
- return self.props == that.props
-
#
# uuid Type
@@ -1208,9 +1187,6 @@ class UUIDSchema(LogicalSchema, PrimitiveSchema):
return self
- def __eq__(self, that):
- return self.props == that.props
-
#
# Module Methods
#
diff --git a/lang/py/avro/test/test_schema.py b/lang/py/avro/test/test_schema.py
index 6e5516a..b9f2887 100644
--- a/lang/py/avro/test/test_schema.py
+++ b/lang/py/avro/test/test_schema.py
@@ -597,7 +597,12 @@ class OtherAttributesTestCase(unittest.TestCase):
def check_attributes(self):
"""Other attributes and their types on a schema should be preserved."""
sch = self.test_schema.parse()
+ try:
+ self.assertNotEqual(sch, object(), "A schema is never equal to a non-schema instance.")
+ except AttributeError:
+ self.fail("Comparing a schema to a non-schema should be False, but not error.")
round_trip = avro.schema.parse(str(sch))
+ self.assertEqual(sch, round_trip, "A schema should be equal to another schema parsed from the same json.")
self.assertEqual(sch.other_props, round_trip.other_props,
"Properties were not preserved in a round-trip parse.")
self._check_props(sch.other_props)