You are viewing a plain-text version of this content. The canonical link for it did not survive conversion to plain text.
Posted to commits@avro.apache.org by dk...@apache.org on 2018/12/11 18:28:07 UTC
[avro] branch master updated: AVRO-1777: Select best matching
record when writing a union in python (#95)
This is an automated email from the ASF dual-hosted git repository.
dkulp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git
The following commit(s) were added to refs/heads/master by this push:
new beef866 AVRO-1777: Select best matching record when writing a union in python (#95)
beef866 is described below
commit beef86697a0dbd09d4d99a735f8a5afe37e5d976
Author: shiraeeshi <sh...@mail.ru>
AuthorDate: Wed Dec 12 00:28:03 2018 +0600
AVRO-1777: Select best matching record when writing a union in python (#95)
* make numeric schemas not valid for boolean datums
* fix formatting
* fix int and long schemas
* fix brackets
* fix brackets
* add GenericRecord type
---
lang/py/src/avro/io.py | 43 ++++++++++++++++++++++++++++++++++++-------
lang/py/test/test_io.py | 11 ++++++++++-
2 files changed, 46 insertions(+), 8 deletions(-)
diff --git a/lang/py/src/avro/io.py b/lang/py/src/avro/io.py
index b2fd2f9..63907e1 100644
--- a/lang/py/src/avro/io.py
+++ b/lang/py/src/avro/io.py
@@ -94,6 +94,10 @@ class SchemaResolutionException(schema.AvroException):
if readers_schema: fail_msg += "\nReader's Schema: %s" % pretty_readers
schema.AvroException.__init__(self, fail_msg)
+class RecordInitializationException(schema.AvroException):  # raised by GenericRecord.__init__ when the schema is None or not a schema.Schema
+ def __init__(self, fail_msg):  # fail_msg: message text handed unchanged to AvroException
+ schema.AvroException.__init__(self, fail_msg)
+
#
# Validate
#
@@ -110,14 +114,17 @@ def validate(expected_schema, datum):
elif schema_type == 'bytes':
return isinstance(datum, str)
elif schema_type == 'int':
- return ((isinstance(datum, int) or isinstance(datum, long))
- and INT_MIN_VALUE <= datum <= INT_MAX_VALUE)
+ return (((isinstance(datum, int) and not isinstance(datum, bool)) or
+ isinstance(datum, long)) and
+ INT_MIN_VALUE <= datum <= INT_MAX_VALUE)
elif schema_type == 'long':
- return ((isinstance(datum, int) or isinstance(datum, long))
- and LONG_MIN_VALUE <= datum <= LONG_MAX_VALUE)
+ return (((isinstance(datum, int) and not isinstance(datum, bool)) or
+ isinstance(datum, long)) and
+ LONG_MIN_VALUE <= datum <= LONG_MAX_VALUE)
elif schema_type in ['float', 'double']:
- return (isinstance(datum, int) or isinstance(datum, long)
- or isinstance(datum, float))
+ return (isinstance(datum, long) or
+ (isinstance(datum, int) and not isinstance(datum, bool)) or
+ isinstance(datum, float))
elif schema_type == 'fixed':
return isinstance(datum, str) and len(datum) == expected_schema.size
elif schema_type == 'enum':
@@ -132,6 +139,8 @@ def validate(expected_schema, datum):
[validate(expected_schema.values, v) for v in datum.values()])
elif schema_type in ['union', 'error_union']:
return True in [validate(s, datum) for s in expected_schema.schemas]
+ elif schema_type == 'record' and isinstance(datum, GenericRecord):
+ return expected_schema == datum.schema
elif schema_type in ['record', 'error', 'request']:
return (isinstance(datum, dict) and
False not in
@@ -683,7 +692,7 @@ class DatumReader(object):
"""
# schema resolution
readers_fields_dict = readers_schema.fields_dict
- read_record = {}
+ read_record = GenericRecord(readers_schema)
for field in writers_schema.fields:
readers_field = readers_fields_dict.get(field.name)
if readers_field is not None:
@@ -888,3 +897,23 @@ class DatumWriter(object):
"""
for field in writers_schema.fields:
self.write_data(field.type, datum.get(field.name), encoder)
+
+class GenericRecord(dict):  # a dict of field values that also remembers the record schema it belongs to (self.schema)
+
+ def __init__(self, record_schema, lst = []):  # lst: initial items for dict.__init__; the shared [] default is only read (dict.__init__ copies it), never mutated
+ if (record_schema is None or  # fail fast on a missing or non-Schema schema
+ not isinstance(record_schema, schema.Schema)):
+ raise RecordInitializationException(
+ "Cannot initialize a record with schema: {sc}".format(sc = record_schema))
+ dict.__init__(self, lst)
+ self.schema = record_schema
+
+ def __eq__(self, other):  # NOTE(review): no matching __ne__; on Python 2 `!=` falls back to dict.__ne__, which ignores .schema
+ if other is None or not isinstance(other, dict):  # only mappings can compare equal
+ return False
+ if not dict.__eq__(self, other):  # field values must match first
+ return False
+ if isinstance(other, GenericRecord):  # two GenericRecords must also share a schema
+ return self.schema == other.schema
+ else:  # a plain dict with the same contents counts as equal
+ return True
diff --git a/lang/py/test/test_io.py b/lang/py/test/test_io.py
index 1e79d3e..d6e341a 100644
--- a/lang/py/test/test_io.py
+++ b/lang/py/test/test_io.py
@@ -39,6 +39,8 @@ SCHEMAS_TO_VALIDATE = (
('{"type": "array", "items": "long"}', [1, 3, 2]),
('{"type": "map", "values": "long"}', {'a': 1, 'b': 3, 'c': 2}),
('["string", "null", "long"]', None),
+ ('["double", "boolean"]', True),
+ ('["boolean", "double"]', True),
("""\
{"type": "record",
"name": "Test",
@@ -190,6 +192,13 @@ class TestIO(unittest.TestCase):
def test_round_trip(self):
print_test_name('TEST ROUND TRIP')
correct = 0
+ def are_equal(datum, round_trip_datum):  # stricter than ==: a bool datum must also come back as a bool
+ if datum != round_trip_datum:
+ return False
+ if type(datum) == bool:  # True == 1 == 1.0, so value equality alone would accept a bool read back as a double/long
+ return type(round_trip_datum) == bool
+ else:
+ return True
for example_schema, datum in SCHEMAS_TO_VALIDATE:
print 'Schema: %s' % example_schema
print 'Datum: %s' % datum
@@ -199,7 +208,7 @@ class TestIO(unittest.TestCase):
round_trip_datum = read_datum(writer, writers_schema)
print 'Round Trip Datum: %s' % round_trip_datum
- if datum == round_trip_datum: correct += 1
+ if are_equal(datum, round_trip_datum): correct += 1
self.assertEquals(correct, len(SCHEMAS_TO_VALIDATE))
#