You are viewing a plain-text version of this content. The canonical link to it (originally a hyperlink on the word "here") was not preserved in this text version.
Posted to commits@avro.apache.org by ko...@apache.org on 2021/05/21 13:11:22 UTC

[avro] branch master updated: AVRO-3104 Allow Comparing Schemas to non-Schemas (#1221)

This is an automated email from the ASF dual-hosted git repository.

kojiromike pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git


The following commit(s) were added to refs/heads/master by this push:
     new 72a0fef  AVRO-3104 Allow Comparing Schemas to non-Schemas (#1221)
72a0fef is described below

commit 72a0fef3e9247bd94621288fb2e6d88408e3c836
Author: Michael A. Smith <mi...@smith-li.com>
AuthorDate: Fri May 21 09:11:07 2021 -0400

    AVRO-3104 Allow Comparing Schemas to non-Schemas (#1221)
    
    Avro no longer crashes with an `AttributeError` when you test whether a schema is equal to an object that is not a schema, as in `some_schema == not_a_schema`; such a comparison now simply returns `False`.
---
 lang/py/avro/schema.py           | 102 +++++++++++++++------------------------
 lang/py/avro/test/test_schema.py |   5 ++
 2 files changed, 44 insertions(+), 63 deletions(-)

diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py
index 7999d0f..43531a3 100644
--- a/lang/py/avro/schema.py
+++ b/lang/py/avro/schema.py
@@ -148,7 +148,26 @@ def _is_timezone_aware_datetime(dt):
 # Base Classes
 #
 
-class CanonicalPropertiesMixin(object):
+class EqualByJsonMixin:
+    """Equal if the json serializations are equal."""
+    def __eq__(self, that):
+        try:
+            that_str = json.loads(str(that))
+        except json.decoder.JSONDecodeError:
+            return False
+        return json.loads(str(self)) == that_str
+
+
+class EqualByPropsMixin:
+    """Equal if the props are equal."""
+    def __eq__(self, that):
+        try:
+            return self.props == that.props
+        except AttributeError:
+            return False
+
+
+class CanonicalPropertiesMixin:
     """A Mixin that provides canonical properties to Schema and Field types."""
     @property
     def canonical_properties(self):
@@ -253,6 +272,13 @@ class Schema(abc.ABC, CanonicalPropertiesMixin):
         # The separators eliminate whitespace around commas and colons.
         return json.dumps(self.to_canonical_json(), separators=(",", ":"))
 
+    @abc.abstractmethod
+    def __eq__(self, that):
+        """
+        Determines how two schema are compared.
+        Consider the mixins EqualByPropsMixin and EqualByJsonMixin
+        """
+
 
 class Name:
     """Class to describe Avro name."""
@@ -301,7 +327,10 @@ class Name:
 
     def __eq__(self, other):
         """Equality of names is defined on the fullname and is case-sensitive."""
-        return isinstance(other, Name) and self.fullname == other.fullname
+        try:
+            return self.fullname == other.fullname
+        except AttributeError:
+            return False
 
     @property
     def fullname(self):
@@ -449,7 +478,7 @@ class DecimalLogicalSchema(LogicalSchema):
         super(DecimalLogicalSchema, self).__init__('decimal')
 
 
-class Field(CanonicalPropertiesMixin):
+class Field(CanonicalPropertiesMixin, EqualByJsonMixin):
     def __init__(self, type, name, has_default, default=None,
                  order=None, names=None, doc=None, other_props=None):
         # Ensure valid ctor args
@@ -526,16 +555,12 @@ class Field(CanonicalPropertiesMixin):
 
         return to_dump
 
-    def __eq__(self, that):
-        to_cmp = json.loads(str(self))
-        return to_cmp == json.loads(str(that))
-
 #
 # Primitive Types
 #
 
 
-class PrimitiveSchema(Schema):
+class PrimitiveSchema(EqualByPropsMixin, Schema):
     """Valid primitive types are in PRIMITIVE_TYPES."""
 
     _validators = {
@@ -589,9 +614,6 @@ class PrimitiveSchema(Schema):
         validator = self._validators.get(self.type, lambda x: False)
         return self if validator(datum) else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 #
 # Decimal Bytes Type
 #
@@ -615,14 +637,11 @@ class BytesDecimalSchema(PrimitiveSchema, DecimalLogicalSchema):
         """Return self if datum is a Decimal object, else None."""
         return self if isinstance(datum, decimal.Decimal) else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 
 #
 # Complex Types (non-recursive)
 #
-class FixedSchema(NamedSchema):
+class FixedSchema(EqualByPropsMixin, NamedSchema):
     def __init__(self, name, namespace, size, names=None, other_props=None):
         # Ensure valid ctor args
         if not isinstance(size, int) or size < 0:
@@ -665,9 +684,6 @@ class FixedSchema(NamedSchema):
         """Return self if datum is a valid representation of this schema, else None."""
         return self if isinstance(datum, bytes) and len(datum) == self.size else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 #
 # Decimal Fixed Type
 #
@@ -692,11 +708,8 @@ class FixedDecimalSchema(FixedSchema, DecimalLogicalSchema):
         """Return self if datum is a Decimal object, else None."""
         return self if isinstance(datum, decimal.Decimal) else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 
-class EnumSchema(NamedSchema):
+class EnumSchema(EqualByPropsMixin, NamedSchema):
     def __init__(self, name, namespace, symbols, names=None, doc=None, other_props=None, validate_enum_symbols=True):
         """
         @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names.
@@ -756,15 +769,12 @@ class EnumSchema(NamedSchema):
         """Return self if datum is a valid member of this Enum, else None."""
         return self if datum in self.symbols else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 #
 # Complex Types (recursive)
 #
 
 
-class ArraySchema(Schema):
+class ArraySchema(EqualByJsonMixin, Schema):
     def __init__(self, items, names=None, other_props=None):
         # Call parent ctor
         Schema.__init__(self, 'array', other_props)
@@ -814,12 +824,8 @@ class ArraySchema(Schema):
         """Return self if datum is a valid representation of this schema, else None."""
         return self if isinstance(datum, list) else None
 
-    def __eq__(self, that):
-        to_cmp = json.loads(str(self))
-        return to_cmp == json.loads(str(that))
-
 
-class MapSchema(Schema):
+class MapSchema(EqualByJsonMixin, Schema):
     def __init__(self, values, names=None, other_props=None):
         # Call parent ctor
         Schema.__init__(self, 'map', other_props)
@@ -868,12 +874,8 @@ class MapSchema(Schema):
         """Return self if datum is a valid representation of this schema, else None."""
         return self if isinstance(datum, dict) and all(isinstance(key, str) for key in datum) else None
 
-    def __eq__(self, that):
-        to_cmp = json.loads(str(self))
-        return to_cmp == json.loads(str(that))
-
 
-class UnionSchema(Schema):
+class UnionSchema(EqualByJsonMixin, Schema):
     """
     names is a dictionary of schema objects
     """
@@ -938,10 +940,6 @@ class UnionSchema(Schema):
             if branch.validate(datum) is not None:
                 return branch
 
-    def __eq__(self, that):
-        to_cmp = json.loads(str(self))
-        return to_cmp == json.loads(str(that))
-
 
 class ErrorUnionSchema(UnionSchema):
     def __init__(self, schemas, names=None):
@@ -961,7 +959,7 @@ class ErrorUnionSchema(UnionSchema):
         return to_dump
 
 
-class RecordSchema(NamedSchema):
+class RecordSchema(EqualByJsonMixin, NamedSchema):
     @staticmethod
     def make_field_objects(field_data, names):
         """We're going to need to make message parameters too."""
@@ -1082,10 +1080,6 @@ class RecordSchema(NamedSchema):
         """Return self if datum is a valid representation of this schema, else None"""
         return self if isinstance(datum, dict) and {f.name for f in self.fields}.issuperset(datum.keys()) else None
 
-    def __eq__(self, that):
-        to_cmp = json.loads(str(self))
-        return to_cmp == json.loads(str(that))
-
 
 #
 # Date Type
@@ -1103,9 +1097,6 @@ class DateSchema(LogicalSchema, PrimitiveSchema):
         """Return self if datum is a valid date object, else None."""
         return self if isinstance(datum, datetime.date) else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 #
 # time-millis Type
 #
@@ -1123,9 +1114,6 @@ class TimeMillisSchema(LogicalSchema, PrimitiveSchema):
         """Return self if datum is a valid representation of this schema, else None."""
         return self if isinstance(datum, datetime.time) else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 #
 # time-micros Type
 #
@@ -1143,9 +1131,6 @@ class TimeMicrosSchema(LogicalSchema, PrimitiveSchema):
         """Return self if datum is a valid representation of this schema, else None."""
         return self if isinstance(datum, datetime.time) else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 #
 # timestamp-millis Type
 #
@@ -1162,9 +1147,6 @@ class TimestampMillisSchema(LogicalSchema, PrimitiveSchema):
     def validate(self, datum):
         return self if isinstance(datum, datetime.datetime) and _is_timezone_aware_datetime(datum) else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 #
 # timestamp-micros Type
 #
@@ -1181,9 +1163,6 @@ class TimestampMicrosSchema(LogicalSchema, PrimitiveSchema):
     def validate(self, datum):
         return self if isinstance(datum, datetime.datetime) and _is_timezone_aware_datetime(datum) else None
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 
 #
 # uuid Type
@@ -1208,9 +1187,6 @@ class UUIDSchema(LogicalSchema, PrimitiveSchema):
 
         return self
 
-    def __eq__(self, that):
-        return self.props == that.props
-
 #
 # Module Methods
 #
diff --git a/lang/py/avro/test/test_schema.py b/lang/py/avro/test/test_schema.py
index 6e5516a..b9f2887 100644
--- a/lang/py/avro/test/test_schema.py
+++ b/lang/py/avro/test/test_schema.py
@@ -597,7 +597,12 @@ class OtherAttributesTestCase(unittest.TestCase):
     def check_attributes(self):
         """Other attributes and their types on a schema should be preserved."""
         sch = self.test_schema.parse()
+        try:
+            self.assertNotEqual(sch, object(), "A schema is never equal to a non-schema instance.")
+        except AttributeError:
+            self.fail("Comparing a schema to a non-schema should be False, but not error.")
         round_trip = avro.schema.parse(str(sch))
+        self.assertEqual(sch, round_trip, "A schema should be equal to another schema parsed from the same json.")
         self.assertEqual(sch.other_props, round_trip.other_props,
                          "Properties were not preserved in a round-trip parse.")
         self._check_props(sch.other_props)