Posted to commits@avro.apache.org by ko...@apache.org on 2020/05/01 13:01:35 UTC
[avro] branch master updated: AVRO-2613: Fix Linter Complaints (#871)
This is an automated email from the ASF dual-hosted git repository.
kojiromike pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git
The following commit(s) were added to refs/heads/master by this push:
new 02f12b8 AVRO-2613: Fix Linter Complaints (#871)
02f12b8 is described below
commit 02f12b87b7a8254b5169c479d44d0e1a5952f1c7
Author: Michael A. Smith <mi...@smith-li.com>
AuthorDate: Fri May 1 09:01:21 2020 -0400
AVRO-2613: Fix Linter Complaints (#871)
* AVRO-2613: Whitespace Fixes (E1* category)
* AVRO-2613: Whitespace (E2* category)
* AVRO-2613: Whitespace (E3* category)
* AVRO-2613: Whitespace (E5* category)
* AVRO-2613: Whitespace (E7* Category)
* AVRO-2613: Whitespace (W6* Category)
---
.editorconfig | 2 +-
lang/py/avro/codecs.py | 244 ++--
lang/py/avro/datafile.py | 557 ++++----
lang/py/avro/io.py | 1979 +++++++++++++-------------
lang/py/avro/ipc.py | 784 +++++-----
lang/py/avro/protocol.py | 399 +++---
lang/py/avro/schema.py | 1729 +++++++++++-----------
lang/py/avro/test/av_bench.py | 20 +-
lang/py/avro/test/gen_interop_data.py | 60 +-
lang/py/avro/test/mock_tether_parent.py | 103 +-
lang/py/avro/test/sample_http_client.py | 72 +-
lang/py/avro/test/sample_http_server.py | 53 +-
lang/py/avro/test/test_datafile.py | 312 ++--
lang/py/avro/test/test_datafile_interop.py | 45 +-
lang/py/avro/test/test_init.py | 8 +-
lang/py/avro/test/test_io.py | 563 ++++----
lang/py/avro/test/test_ipc.py | 17 +-
lang/py/avro/test/test_protocol.py | 525 +++----
lang/py/avro/test/test_schema.py | 930 ++++++------
lang/py/avro/test/test_script.py | 18 +-
lang/py/avro/test/test_tether_task.py | 156 +-
lang/py/avro/test/test_tether_task_runner.py | 306 ++--
lang/py/avro/test/test_tether_word_count.py | 190 +--
lang/py/avro/test/txsample_http_client.py | 96 +-
lang/py/avro/test/txsample_http_server.py | 28 +-
lang/py/avro/test/word_count_task.py | 100 +-
lang/py/avro/tether/tether_task.py | 813 ++++++-----
lang/py/avro/tether/tether_task_runner.py | 317 +++--
lang/py/avro/tether/util.py | 16 +-
lang/py/avro/timezones.py | 26 +-
lang/py/avro/tool.py | 218 +--
lang/py/avro/txipc.py | 364 ++---
lang/py/scripts/avro | 77 +-
lang/py/setup.py | 44 +-
lang/py/tox.ini | 1 -
35 files changed, 5686 insertions(+), 5486 deletions(-)
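
Editorial note: the E and W codes named in the commit message are pycodestyle (formerly pep8) categories: E1* covers indentation, E2* whitespace around tokens, E3* blank lines, E5* line length, E7* statement style, and W6* deprecation warnings. The Avro Python build drives its linting through tox (see lang/py/tox.ini in the diffstat), so the standalone check below is only an illustrative sketch of selecting those categories, not the project's actual configuration:

    # Illustrative only: select the categories named in the commit message
    # via the pycodestyle API. The real build runs its linters through tox,
    # so this exact selection and path are assumptions.
    import pycodestyle

    style = pycodestyle.StyleGuide(select=['E1', 'E2', 'E3', 'E5', 'E7', 'W6'])
    report = style.check_files(['lang/py/avro'])
    print(report.total_errors)
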
diff --git a/.editorconfig b/.editorconfig
index 6e7532c..b2d8a7c 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -36,4 +36,4 @@ trim_trailing_whitespace=true
[*.py]
indent_style = space
-indent_size = 2
+indent_size = 4
diff --git a/lang/py/avro/codecs.py b/lang/py/avro/codecs.py
index 772db3e..3c0f188 100644
--- a/lang/py/avro/codecs.py
+++ b/lang/py/avro/codecs.py
@@ -46,159 +46,159 @@ STRUCT_CRC32 = Struct('>I') # big-endian unsigned int
try:
- import bz2
- has_bzip2 = True
+ import bz2
+ has_bzip2 = True
except ImportError:
- has_bzip2 = False
+ has_bzip2 = False
try:
- import snappy
- has_snappy = True
+ import snappy
+ has_snappy = True
except ImportError:
- has_snappy = False
+ has_snappy = False
try:
- import zstandard as zstd
- has_zstandard = True
+ import zstandard as zstd
+ has_zstandard = True
except ImportError:
- has_zstandard = False
+ has_zstandard = False
class Codec:
- """Abstract base class for all Avro codec classes."""
- __metaclass__ = ABCMeta
+ """Abstract base class for all Avro codec classes."""
+ __metaclass__ = ABCMeta
- @abstractmethod
- def compress(self, data):
- """Compress the passed data.
+ @abstractmethod
+ def compress(self, data):
+ """Compress the passed data.
- :param data: a byte string to be compressed
- :type data: str
+ :param data: a byte string to be compressed
+ :type data: str
- :rtype: tuple
- :return: compressed data and its length
- """
- pass
+ :rtype: tuple
+ :return: compressed data and its length
+ """
+ pass
- @abstractmethod
- def decompress(self, readers_decoder):
- """Read compressed data via the passed BinaryDecoder and decompress it.
+ @abstractmethod
+ def decompress(self, readers_decoder):
+ """Read compressed data via the passed BinaryDecoder and decompress it.
- :param readers_decoder: a BinaryDecoder object currently being used for
- reading an object container file
- :type readers_decoder: avro.io.BinaryDecoder
+ :param readers_decoder: a BinaryDecoder object currently being used for
+ reading an object container file
+ :type readers_decoder: avro.io.BinaryDecoder
- :rtype: avro.io.BinaryDecoder
- :return: a newly instantiated BinaryDecoder object that contains the
- decompressed data which is wrapped by a StringIO
- """
- pass
+ :rtype: avro.io.BinaryDecoder
+ :return: a newly instantiated BinaryDecoder object that contains the
+ decompressed data which is wrapped by a StringIO
+ """
+ pass
class NullCodec(Codec):
- def compress(self, data):
- return data, len(data)
+ def compress(self, data):
+ return data, len(data)
- def decompress(self, readers_decoder):
- readers_decoder.skip_long()
- return readers_decoder
+ def decompress(self, readers_decoder):
+ readers_decoder.skip_long()
+ return readers_decoder
class DeflateCodec(Codec):
- def compress(self, data):
- # The first two characters and last character are zlib
- # wrappers around deflate data.
- compressed_data = zlib.compress(data)[2:-1]
- return compressed_data, len(compressed_data)
-
- def decompress(self, readers_decoder):
- # Compressed data is stored as (length, data), which
- # corresponds to how the "bytes" type is encoded.
- data = readers_decoder.read_bytes()
- # -15 is the log of the window size; negative indicates
- # "raw" (no zlib headers) decompression. See zlib.h.
- uncompressed = zlib.decompress(data, -15)
- return avro.io.BinaryDecoder(io.BytesIO(uncompressed))
-
-
-if has_bzip2:
- class BZip2Codec(Codec):
def compress(self, data):
- compressed_data = bz2.compress(data)
- return compressed_data, len(compressed_data)
+ # The first two characters and last character are zlib
+ # wrappers around deflate data.
+ compressed_data = zlib.compress(data)[2:-1]
+ return compressed_data, len(compressed_data)
def decompress(self, readers_decoder):
- length = readers_decoder.read_long()
- data = readers_decoder.read(length)
- uncompressed = bz2.decompress(data)
- return avro.io.BinaryDecoder(io.BytesIO(uncompressed))
+ # Compressed data is stored as (length, data), which
+ # corresponds to how the "bytes" type is encoded.
+ data = readers_decoder.read_bytes()
+ # -15 is the log of the window size; negative indicates
+ # "raw" (no zlib headers) decompression. See zlib.h.
+ uncompressed = zlib.decompress(data, -15)
+ return avro.io.BinaryDecoder(io.BytesIO(uncompressed))
-if has_snappy:
- class SnappyCodec(Codec):
- def compress(self, data):
- compressed_data = snappy.compress(data)
- # A 4-byte, big-endian CRC32 checksum
- compressed_data += STRUCT_CRC32.pack(crc32(data) & 0xffffffff)
- return compressed_data, len(compressed_data)
+if has_bzip2:
+ class BZip2Codec(Codec):
+ def compress(self, data):
+ compressed_data = bz2.compress(data)
+ return compressed_data, len(compressed_data)
+
+ def decompress(self, readers_decoder):
+ length = readers_decoder.read_long()
+ data = readers_decoder.read(length)
+ uncompressed = bz2.decompress(data)
+ return avro.io.BinaryDecoder(io.BytesIO(uncompressed))
- def decompress(self, readers_decoder):
- # Compressed data includes a 4-byte CRC32 checksum
- length = readers_decoder.read_long()
- data = readers_decoder.read(length - 4)
- uncompressed = snappy.decompress(data)
- checksum = readers_decoder.read(4)
- self.check_crc32(uncompressed, checksum)
- return avro.io.BinaryDecoder(io.BytesIO(uncompressed))
- def check_crc32(self, bytes, checksum):
- checksum = STRUCT_CRC32.unpack(checksum)[0];
- if crc32(bytes) & 0xffffffff != checksum:
- raise schema.AvroException("Checksum failure")
+if has_snappy:
+ class SnappyCodec(Codec):
+ def compress(self, data):
+ compressed_data = snappy.compress(data)
+ # A 4-byte, big-endian CRC32 checksum
+ compressed_data += STRUCT_CRC32.pack(crc32(data) & 0xffffffff)
+ return compressed_data, len(compressed_data)
+
+ def decompress(self, readers_decoder):
+ # Compressed data includes a 4-byte CRC32 checksum
+ length = readers_decoder.read_long()
+ data = readers_decoder.read(length - 4)
+ uncompressed = snappy.decompress(data)
+ checksum = readers_decoder.read(4)
+ self.check_crc32(uncompressed, checksum)
+ return avro.io.BinaryDecoder(io.BytesIO(uncompressed))
+
+ def check_crc32(self, bytes, checksum):
+ checksum = STRUCT_CRC32.unpack(checksum)[0]
+ if crc32(bytes) & 0xffffffff != checksum:
+ raise schema.AvroException("Checksum failure")
if has_zstandard:
- class ZstandardCodec(Codec):
- def compress(self, data):
- compressed_data = zstd.ZstdCompressor().compress(data)
- return compressed_data, len(compressed_data)
-
- def decompress(self, readers_decoder):
- length = readers_decoder.read_long()
- data = readers_decoder.read(length)
- uncompressed = bytearray()
- dctx = zstd.ZstdDecompressor()
- with dctx.stream_reader(io.BytesIO(data)) as reader:
- while True:
- chunk = reader.read(16384)
- if not chunk:
- break
- uncompressed.extend(chunk)
- return avro.io.BinaryDecoder(io.BytesIO(uncompressed))
+ class ZstandardCodec(Codec):
+ def compress(self, data):
+ compressed_data = zstd.ZstdCompressor().compress(data)
+ return compressed_data, len(compressed_data)
+
+ def decompress(self, readers_decoder):
+ length = readers_decoder.read_long()
+ data = readers_decoder.read(length)
+ uncompressed = bytearray()
+ dctx = zstd.ZstdDecompressor()
+ with dctx.stream_reader(io.BytesIO(data)) as reader:
+ while True:
+ chunk = reader.read(16384)
+ if not chunk:
+ break
+ uncompressed.extend(chunk)
+ return avro.io.BinaryDecoder(io.BytesIO(uncompressed))
class Codecs(object):
- @staticmethod
- def get_codec(codec_name):
- codec_name = codec_name.lower()
- if codec_name == "null":
- return NullCodec()
- elif codec_name == "deflate":
- return DeflateCodec()
- elif codec_name == "bzip2" and has_bzip2:
- return BZip2Codec()
- elif codec_name == "snappy" and has_snappy:
- return SnappyCodec()
- elif codec_name == "zstandard" and has_zstandard:
- return ZstandardCodec()
- else:
- raise ValueError("Unsupported codec: %r" % codec_name)
-
- @staticmethod
- def supported_codec_names():
- codec_names = ['null', 'deflate']
- if has_bzip2:
- codec_names.append('bzip2')
- if has_snappy:
- codec_names.append('snappy')
- if has_zstandard:
- codec_names.append('zstandard')
- return codec_names
+ @staticmethod
+ def get_codec(codec_name):
+ codec_name = codec_name.lower()
+ if codec_name == "null":
+ return NullCodec()
+ elif codec_name == "deflate":
+ return DeflateCodec()
+ elif codec_name == "bzip2" and has_bzip2:
+ return BZip2Codec()
+ elif codec_name == "snappy" and has_snappy:
+ return SnappyCodec()
+ elif codec_name == "zstandard" and has_zstandard:
+ return ZstandardCodec()
+ else:
+ raise ValueError("Unsupported codec: %r" % codec_name)
+
+ @staticmethod
+ def supported_codec_names():
+ codec_names = ['null', 'deflate']
+ if has_bzip2:
+ codec_names.append('bzip2')
+ if has_snappy:
+ codec_names.append('snappy')
+ if has_zstandard:
+ codec_names.append('zstandard')
+ return codec_names
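
Editorial note: for readers skimming the re-indented codecs.py above, the module's public surface is the Codecs registry, which maps codec names to Codec instances and degrades gracefully when the optional bz2, snappy, or zstandard imports are unavailable. A minimal usage sketch, assuming only the always-available null and deflate codecs:

    # Minimal sketch of the avro.codecs API shown in the diff above.
    # Only 'null' and 'deflate' are guaranteed; bzip2/snappy/zstandard
    # appear in supported_codec_names() only when their imports succeed.
    from avro.codecs import Codecs

    print(Codecs.supported_codec_names())        # e.g. ['null', 'deflate']

    codec = Codecs.get_codec('deflate')
    compressed, length = codec.compress(b'raw avro block bytes')
    assert length == len(compressed)
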
diff --git a/lang/py/avro/datafile.py b/lang/py/avro/datafile.py
index 6d9222b..14e332d 100644
--- a/lang/py/avro/datafile.py
+++ b/lang/py/avro/datafile.py
@@ -37,7 +37,7 @@ VERSION = 1
MAGIC = bytes(b'Obj' + bytearray([VERSION]))
MAGIC_SIZE = len(MAGIC)
SYNC_SIZE = 16
-SYNC_INTERVAL = 4000 * SYNC_SIZE # TODO(hammer): make configurable
+SYNC_INTERVAL = 4000 * SYNC_SIZE # TODO(hammer): make configurable
META_SCHEMA = avro.schema.parse("""\
{"type": "record", "name": "org.apache.avro.file.Header",
"fields" : [
@@ -48,7 +48,7 @@ META_SCHEMA = avro.schema.parse("""\
NULL_CODEC = 'null'
VALID_CODECS = Codecs.supported_codec_names()
-VALID_ENCODINGS = ['binary'] # not used yet
+VALID_ENCODINGS = ['binary'] # not used yet
CODEC_KEY = "avro.codec"
SCHEMA_KEY = "avro.schema"
@@ -57,294 +57,299 @@ SCHEMA_KEY = "avro.schema"
# Exceptions
#
+
class DataFileException(avro.schema.AvroException):
- """
- Raised when there's a problem reading or writing file object containers.
- """
- def __init__(self, fail_msg):
- avro.schema.AvroException.__init__(self, fail_msg)
+ """
+ Raised when there's a problem reading or writing file object containers.
+ """
+
+ def __init__(self, fail_msg):
+ avro.schema.AvroException.__init__(self, fail_msg)
#
# Write Path
#
-class _DataFile(object):
- """Mixin for methods common to both reading and writing."""
-
- block_count = 0
- _meta = None
- _sync_marker = None
-
- def __enter__(self):
- return self
-
- def __exit__(self, type, value, traceback):
- # Perform a close if there's no exception
- if type is None:
- self.close()
-
- def get_meta(self, key):
- return self.meta.get(key)
-
- def set_meta(self, key, val):
- self.meta[key] = val
- @property
- def sync_marker(self):
- return self._sync_marker
-
- @property
- def meta(self):
- """Read-only dictionary of metadata for this datafile."""
- if self._meta is None:
- self._meta = {}
- return self._meta
-
- @property
- def codec(self):
- """Meta are stored as bytes, but codec is returned as a string."""
- try:
- return self.get_meta(CODEC_KEY).decode()
- except AttributeError:
- return "null"
-
- @codec.setter
- def codec(self, value):
- """Meta are stored as bytes, but codec is set as a string."""
- if value not in VALID_CODECS:
- raise DataFileException("Unknown codec: {!r}".format(value))
- self.set_meta(CODEC_KEY, value.encode())
-
- @property
- def schema(self):
- """Meta are stored as bytes, but schema is returned as a string."""
- return self.get_meta(SCHEMA_KEY).decode()
-
- @schema.setter
- def schema(self, value):
- """Meta are stored as bytes, but schema is set as a string."""
- self.set_meta(SCHEMA_KEY, value.encode())
+class _DataFile(object):
+ """Mixin for methods common to both reading and writing."""
+
+ block_count = 0
+ _meta = None
+ _sync_marker = None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ # Perform a close if there's no exception
+ if type is None:
+ self.close()
+
+ def get_meta(self, key):
+ return self.meta.get(key)
+
+ def set_meta(self, key, val):
+ self.meta[key] = val
+
+ @property
+ def sync_marker(self):
+ return self._sync_marker
+
+ @property
+ def meta(self):
+ """Read-only dictionary of metadata for this datafile."""
+ if self._meta is None:
+ self._meta = {}
+ return self._meta
+
+ @property
+ def codec(self):
+ """Meta are stored as bytes, but codec is returned as a string."""
+ try:
+ return self.get_meta(CODEC_KEY).decode()
+ except AttributeError:
+ return "null"
+
+ @codec.setter
+ def codec(self, value):
+ """Meta are stored as bytes, but codec is set as a string."""
+ if value not in VALID_CODECS:
+ raise DataFileException("Unknown codec: {!r}".format(value))
+ self.set_meta(CODEC_KEY, value.encode())
+
+ @property
+ def schema(self):
+ """Meta are stored as bytes, but schema is returned as a string."""
+ return self.get_meta(SCHEMA_KEY).decode()
+
+ @schema.setter
+ def schema(self, value):
+ """Meta are stored as bytes, but schema is set as a string."""
+ self.set_meta(SCHEMA_KEY, value.encode())
class DataFileWriter(_DataFile):
- # TODO(hammer): make 'encoder' a metadata property
- def __init__(self, writer, datum_writer, writers_schema=None, codec=NULL_CODEC):
- """
- If the schema is not present, presume we're appending.
+ # TODO(hammer): make 'encoder' a metadata property
+ def __init__(self, writer, datum_writer, writers_schema=None, codec=NULL_CODEC):
+ """
+ If the schema is not present, presume we're appending.
+
+ @param writer: File-like object to write into.
+ """
+ self._writer = writer
+ self._encoder = avro.io.BinaryEncoder(writer)
+ self._datum_writer = datum_writer
+ self._buffer_writer = io.BytesIO()
+ self._buffer_encoder = avro.io.BinaryEncoder(self._buffer_writer)
+ self.block_count = 0
+ self._header_written = False
+
+ if writers_schema is not None:
+ self._sync_marker = generate_sixteen_random_bytes()
+ self.codec = codec
+ self.schema = str(writers_schema)
+ self.datum_writer.writers_schema = writers_schema
+ else:
+ # open writer for reading to collect metadata
+ dfr = DataFileReader(writer, avro.io.DatumReader())
+
+ # TODO(hammer): collect arbitrary metadata
+ # collect metadata
+ self._sync_marker = dfr.sync_marker
+ self.codec = dfr.codec
+
+ # get schema used to write existing file
+ self.schema = schema_from_file = dfr.schema
+ self.datum_writer.writers_schema = avro.schema.parse(schema_from_file)
+
+ # seek to the end of the file and prepare for writing
+ writer.seek(0, 2)
+ self._header_written = True
+
+ # read-only properties
+ writer = property(lambda self: self._writer)
+ encoder = property(lambda self: self._encoder)
+ datum_writer = property(lambda self: self._datum_writer)
+ buffer_writer = property(lambda self: self._buffer_writer)
+ buffer_encoder = property(lambda self: self._buffer_encoder)
+
+ def _write_header(self):
+ header = {'magic': MAGIC,
+ 'meta': self.meta,
+ 'sync': self.sync_marker}
+ self.datum_writer.write_data(META_SCHEMA, header, self.encoder)
+ self._header_written = True
+
+ @property
+ def codec(self):
+ """Meta are stored as bytes, but codec is returned as a string."""
+ return self.get_meta(CODEC_KEY).decode()
+
+ @codec.setter
+ def codec(self, value):
+ """Meta are stored as bytes, but codec is set as a string."""
+ if value not in VALID_CODECS:
+ raise DataFileException("Unknown codec: {!r}".format(value))
+ self.set_meta(CODEC_KEY, value.encode())
+
+ # TODO(hammer): make a schema for blocks and use datum_writer
+ def _write_block(self):
+ if not self._header_written:
+ self._write_header()
+
+ if self.block_count > 0:
+ # write number of items in block
+ self.encoder.write_long(self.block_count)
+
+ # write block contents
+ uncompressed_data = self.buffer_writer.getvalue()
+ codec = Codecs.get_codec(self.codec)
+ compressed_data, compressed_data_length = codec.compress(uncompressed_data)
+
+ # Write length of block
+ self.encoder.write_long(compressed_data_length)
+
+ # Write block
+ self.writer.write(compressed_data)
+
+ # write sync marker
+ self.writer.write(self.sync_marker)
+
+ # reset buffer
+ self.buffer_writer.truncate(0)
+ self.buffer_writer.seek(0)
+ self.block_count = 0
+
+ def append(self, datum):
+ """Append a datum to the file."""
+ self.datum_writer.write(datum, self.buffer_encoder)
+ self.block_count += 1
+
+ # if the data to write is larger than the sync interval, write the block
+ if self.buffer_writer.tell() >= SYNC_INTERVAL:
+ self._write_block()
+
+ def sync(self):
+ """
+ Return the current position as a value that may be passed to
+ DataFileReader.seek(long). Forces the end of the current block,
+ emitting a synchronization marker.
+ """
+ self._write_block()
+ return self.writer.tell()
+
+ def flush(self):
+ """Flush the current state of the file, including metadata."""
+ self._write_block()
+ self.writer.flush()
+
+ def close(self):
+ """Close the file."""
+ self.flush()
+ self.writer.close()
- @param writer: File-like object to write into.
- """
- self._writer = writer
- self._encoder = avro.io.BinaryEncoder(writer)
- self._datum_writer = datum_writer
- self._buffer_writer = io.BytesIO()
- self._buffer_encoder = avro.io.BinaryEncoder(self._buffer_writer)
- self.block_count = 0
- self._header_written = False
-
- if writers_schema is not None:
- self._sync_marker = generate_sixteen_random_bytes()
- self.codec = codec
- self.schema = str(writers_schema)
- self.datum_writer.writers_schema = writers_schema
- else:
- # open writer for reading to collect metadata
- dfr = DataFileReader(writer, avro.io.DatumReader())
-
- # TODO(hammer): collect arbitrary metadata
- # collect metadata
- self._sync_marker = dfr.sync_marker
- self.codec = dfr.codec
-
- # get schema used to write existing file
- self.schema = schema_from_file = dfr.schema
- self.datum_writer.writers_schema = avro.schema.parse(schema_from_file)
-
- # seek to the end of the file and prepare for writing
- writer.seek(0, 2)
- self._header_written = True
-
- # read-only properties
- writer = property(lambda self: self._writer)
- encoder = property(lambda self: self._encoder)
- datum_writer = property(lambda self: self._datum_writer)
- buffer_writer = property(lambda self: self._buffer_writer)
- buffer_encoder = property(lambda self: self._buffer_encoder)
-
- def _write_header(self):
- header = {'magic': MAGIC,
- 'meta': self.meta,
- 'sync': self.sync_marker}
- self.datum_writer.write_data(META_SCHEMA, header, self.encoder)
- self._header_written = True
-
- @property
- def codec(self):
- """Meta are stored as bytes, but codec is returned as a string."""
- return self.get_meta(CODEC_KEY).decode()
-
- @codec.setter
- def codec(self, value):
- """Meta are stored as bytes, but codec is set as a string."""
- if value not in VALID_CODECS:
- raise DataFileException("Unknown codec: {!r}".format(value))
- self.set_meta(CODEC_KEY, value.encode())
-
- # TODO(hammer): make a schema for blocks and use datum_writer
- def _write_block(self):
- if not self._header_written:
- self._write_header()
-
- if self.block_count > 0:
- # write number of items in block
- self.encoder.write_long(self.block_count)
-
- # write block contents
- uncompressed_data = self.buffer_writer.getvalue()
- codec = Codecs.get_codec(self.codec)
- compressed_data, compressed_data_length = codec.compress(uncompressed_data)
-
- # Write length of block
- self.encoder.write_long(compressed_data_length)
-
- # Write block
- self.writer.write(compressed_data)
-
- # write sync marker
- self.writer.write(self.sync_marker)
-
- # reset buffer
- self.buffer_writer.truncate(0)
- self.buffer_writer.seek(0)
- self.block_count = 0
-
- def append(self, datum):
- """Append a datum to the file."""
- self.datum_writer.write(datum, self.buffer_encoder)
- self.block_count += 1
-
- # if the data to write is larger than the sync interval, write the block
- if self.buffer_writer.tell() >= SYNC_INTERVAL:
- self._write_block()
-
- def sync(self):
- """
- Return the current position as a value that may be passed to
- DataFileReader.seek(long). Forces the end of the current block,
- emitting a synchronization marker.
- """
- self._write_block()
- return self.writer.tell()
-
- def flush(self):
- """Flush the current state of the file, including metadata."""
- self._write_block()
- self.writer.flush()
-
- def close(self):
- """Close the file."""
- self.flush()
- self.writer.close()
class DataFileReader(_DataFile):
- """Read files written by DataFileWriter."""
- # TODO(hammer): allow user to specify expected schema?
- # TODO(hammer): allow user to specify the encoder
- def __init__(self, reader, datum_reader):
- self._reader = reader
- self._raw_decoder = avro.io.BinaryDecoder(reader)
- self._datum_decoder = None # Maybe reset at every block.
- self._datum_reader = datum_reader
-
- # read the header: magic, meta, sync
- self._read_header()
-
- # get file length
- self._file_length = self.determine_file_length()
-
- # get ready to read
- self.block_count = 0
- self.datum_reader.writers_schema = avro.schema.parse(self.schema)
-
- def __iter__(self):
- return self
-
- # read-only properties
- reader = property(lambda self: self._reader)
- raw_decoder = property(lambda self: self._raw_decoder)
- datum_decoder = property(lambda self: self._datum_decoder)
- datum_reader = property(lambda self: self._datum_reader)
- file_length = property(lambda self: self._file_length)
-
- def determine_file_length(self):
- """
- Get file length and leave file cursor where we found it.
- """
- remember_pos = self.reader.tell()
- self.reader.seek(0, 2)
- file_length = self.reader.tell()
- self.reader.seek(remember_pos)
- return file_length
-
- def is_EOF(self):
- return self.reader.tell() == self.file_length
-
- def _read_header(self):
- # seek to the beginning of the file to get magic block
- self.reader.seek(0, 0)
-
- # read header into a dict
- header = self.datum_reader.read_data(
- META_SCHEMA, META_SCHEMA, self.raw_decoder)
-
- # check magic number
- if header.get('magic') != MAGIC:
- fail_msg = "Not an Avro data file: %s doesn't match %s."\
- % (header.get('magic'), MAGIC)
- raise avro.schema.AvroException(fail_msg)
-
- # set metadata
- self._meta = header['meta']
-
- # set sync marker
- self._sync_marker = header['sync']
-
- def _read_block_header(self):
- self.block_count = self.raw_decoder.read_long()
- codec = Codecs.get_codec(self.codec)
- self._datum_decoder = codec.decompress(self.raw_decoder)
-
- def _skip_sync(self):
- """
- Read the length of the sync marker; if it matches the sync marker,
- return True. Otherwise, seek back to where we started and return False.
- """
- proposed_sync_marker = self.reader.read(SYNC_SIZE)
- if proposed_sync_marker != self.sync_marker:
- self.reader.seek(-SYNC_SIZE, 1)
- return False
- return True
-
- def __next__(self):
- """Return the next datum in the file."""
- while self.block_count == 0:
- if self.is_EOF() or (self._skip_sync() and self.is_EOF()):
- raise StopIteration
- self._read_block_header()
-
- datum = self.datum_reader.read(self.datum_decoder)
- self.block_count -= 1
- return datum
- next = __next__
-
- def close(self):
- """Close this reader."""
- self.reader.close()
+ """Read files written by DataFileWriter."""
+ # TODO(hammer): allow user to specify expected schema?
+ # TODO(hammer): allow user to specify the encoder
+
+ def __init__(self, reader, datum_reader):
+ self._reader = reader
+ self._raw_decoder = avro.io.BinaryDecoder(reader)
+ self._datum_decoder = None # Maybe reset at every block.
+ self._datum_reader = datum_reader
+
+ # read the header: magic, meta, sync
+ self._read_header()
+
+ # get file length
+ self._file_length = self.determine_file_length()
+
+ # get ready to read
+ self.block_count = 0
+ self.datum_reader.writers_schema = avro.schema.parse(self.schema)
+
+ def __iter__(self):
+ return self
+
+ # read-only properties
+ reader = property(lambda self: self._reader)
+ raw_decoder = property(lambda self: self._raw_decoder)
+ datum_decoder = property(lambda self: self._datum_decoder)
+ datum_reader = property(lambda self: self._datum_reader)
+ file_length = property(lambda self: self._file_length)
+
+ def determine_file_length(self):
+ """
+ Get file length and leave file cursor where we found it.
+ """
+ remember_pos = self.reader.tell()
+ self.reader.seek(0, 2)
+ file_length = self.reader.tell()
+ self.reader.seek(remember_pos)
+ return file_length
+
+ def is_EOF(self):
+ return self.reader.tell() == self.file_length
+
+ def _read_header(self):
+ # seek to the beginning of the file to get magic block
+ self.reader.seek(0, 0)
+
+ # read header into a dict
+ header = self.datum_reader.read_data(
+ META_SCHEMA, META_SCHEMA, self.raw_decoder)
+
+ # check magic number
+ if header.get('magic') != MAGIC:
+ fail_msg = "Not an Avro data file: %s doesn't match %s."\
+ % (header.get('magic'), MAGIC)
+ raise avro.schema.AvroException(fail_msg)
+
+ # set metadata
+ self._meta = header['meta']
+
+ # set sync marker
+ self._sync_marker = header['sync']
+
+ def _read_block_header(self):
+ self.block_count = self.raw_decoder.read_long()
+ codec = Codecs.get_codec(self.codec)
+ self._datum_decoder = codec.decompress(self.raw_decoder)
+
+ def _skip_sync(self):
+ """
+ Read the length of the sync marker; if it matches the sync marker,
+ return True. Otherwise, seek back to where we started and return False.
+ """
+ proposed_sync_marker = self.reader.read(SYNC_SIZE)
+ if proposed_sync_marker != self.sync_marker:
+ self.reader.seek(-SYNC_SIZE, 1)
+ return False
+ return True
+
+ def __next__(self):
+ """Return the next datum in the file."""
+ while self.block_count == 0:
+ if self.is_EOF() or (self._skip_sync() and self.is_EOF()):
+ raise StopIteration
+ self._read_block_header()
+
+ datum = self.datum_reader.read(self.datum_decoder)
+ self.block_count -= 1
+ return datum
+ next = __next__
+
+ def close(self):
+ """Close this reader."""
+ self.reader.close()
def generate_sixteen_random_bytes():
- try:
- return os.urandom(16)
- except NotImplementedError:
- return bytes(random.randrange(256) for i in range(16))
+ try:
+ return os.urandom(16)
+ except NotImplementedError:
+ return bytes(random.randrange(256) for i in range(16))
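
Editorial note: the datafile.py changes above are whitespace-only, but the DataFileWriter/DataFileReader pair they touch is the main entry point for Avro object container files. A hedged end-to-end sketch follows; the record schema and file name are illustrative, and DatumWriter comes from avro.io alongside the DatumReader referenced in the diff:

    # Illustrative round trip through the classes re-indented above; both
    # support the context-manager protocol via the _DataFile mixin.
    import avro.io
    import avro.schema
    from avro.datafile import DataFileReader, DataFileWriter

    SCHEMA = avro.schema.parse(
        '{"type": "record", "name": "Example",'
        ' "fields": [{"name": "x", "type": "int"}]}')

    with DataFileWriter(open('example.avro', 'wb'), avro.io.DatumWriter(),
                        SCHEMA, codec='deflate') as writer:
        writer.append({'x': 1})
        writer.append({'x': 2})

    with DataFileReader(open('example.avro', 'rb'), avro.io.DatumReader()) as reader:
        for record in reader:
            print(record)
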
diff --git a/lang/py/avro/io.py b/lang/py/avro/io.py
index ffbd0af..b910ba5 100644
--- a/lang/py/avro/io.py
+++ b/lang/py/avro/io.py
@@ -53,19 +53,19 @@ from struct import Struct
from avro import constants, schema, timezones
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
try:
- basestring # type: ignore
+ basestring # type: ignore
except NameError:
- basestring = (bytes, unicode)
+ basestring = (bytes, unicode)
try:
- long
+ long
except NameError:
- long = int
+ long = int
#
@@ -93,94 +93,103 @@ STRUCT_SIGNED_LONG = Struct('>q') # big-endian signed long
#
class AvroTypeException(schema.AvroException):
- """Raised when datum is not an example of schema."""
- def __init__(self, expected_schema, datum):
- pretty_expected = json.dumps(json.loads(str(expected_schema)), indent=2)
- fail_msg = "The datum %s is not an example of the schema %s"\
- % (datum, pretty_expected)
- schema.AvroException.__init__(self, fail_msg)
+ """Raised when datum is not an example of schema."""
+
+ def __init__(self, expected_schema, datum):
+ pretty_expected = json.dumps(json.loads(str(expected_schema)), indent=2)
+ fail_msg = "The datum %s is not an example of the schema %s"\
+ % (datum, pretty_expected)
+ schema.AvroException.__init__(self, fail_msg)
+
class SchemaResolutionException(schema.AvroException):
- def __init__(self, fail_msg, writers_schema=None, readers_schema=None):
- pretty_writers = json.dumps(json.loads(str(writers_schema)), indent=2)
- pretty_readers = json.dumps(json.loads(str(readers_schema)), indent=2)
- if writers_schema: fail_msg += "\nWriter's Schema: %s" % pretty_writers
- if readers_schema: fail_msg += "\nReader's Schema: %s" % pretty_readers
- schema.AvroException.__init__(self, fail_msg)
+ def __init__(self, fail_msg, writers_schema=None, readers_schema=None):
+ pretty_writers = json.dumps(json.loads(str(writers_schema)), indent=2)
+ pretty_readers = json.dumps(json.loads(str(readers_schema)), indent=2)
+ if writers_schema:
+ fail_msg += "\nWriter's Schema: %s" % pretty_writers
+ if readers_schema:
+ fail_msg += "\nReader's Schema: %s" % pretty_readers
+ schema.AvroException.__init__(self, fail_msg)
#
# Validate
#
+
+
def _is_timezone_aware_datetime(dt):
- return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None
+ return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None
+
_valid = {
- 'null': lambda s, d: d is None,
- 'boolean': lambda s, d: isinstance(d, bool),
- 'string': lambda s, d: isinstance(d, unicode),
- 'bytes': lambda s, d: ((isinstance(d, bytes)) or
- (isinstance(d, Decimal) and
- getattr(s, 'logical_type', None) == constants.DECIMAL)),
- 'int': lambda s, d: ((isinstance(d, (int, long))) and (INT_MIN_VALUE <= d <= INT_MAX_VALUE) or
- (isinstance(d, datetime.date) and
- getattr(s, 'logical_type', None) == constants.DATE) or
- (isinstance(d, datetime.time) and
- getattr(s, 'logical_type', None) == constants.TIME_MILLIS)),
- 'long': lambda s, d: ((isinstance(d, (int, long))) and (LONG_MIN_VALUE <= d <= LONG_MAX_VALUE) or
- (isinstance(d, datetime.time) and
- getattr(s, 'logical_type', None) == constants.TIME_MICROS) or
- (isinstance(d, datetime.date) and
- _is_timezone_aware_datetime(d) and
- getattr(s, 'logical_type', None) in (constants.TIMESTAMP_MILLIS,
- constants.TIMESTAMP_MICROS))),
- 'float': lambda s, d: isinstance(d, (int, long, float)),
- 'fixed': lambda s, d: ((isinstance(d, bytes) and len(d) == s.size) or
- (isinstance(d, Decimal) and
- getattr(s, 'logical_type', None) == constants.DECIMAL)),
- 'enum': lambda s, d: d in s.symbols,
-
- 'array': lambda s, d: isinstance(d, list) and all(validate(s.items, item) for item in d),
- 'map': lambda s, d: (isinstance(d, dict) and all(isinstance(key, unicode) for key in d)
- and all(validate(s.values, value) for value in d.values())),
- 'union': lambda s, d: any(validate(branch, d) for branch in s.schemas),
- 'record': lambda s, d: (isinstance(d, dict)
- and all(validate(f.type, d.get(f.name)) for f in s.fields)
- and {f.name for f in s.fields}.issuperset(d.keys())),
+ 'null': lambda s, d: d is None,
+ 'boolean': lambda s, d: isinstance(d, bool),
+ 'string': lambda s, d: isinstance(d, unicode),
+ 'bytes': lambda s, d: ((isinstance(d, bytes)) or
+ (isinstance(d, Decimal) and
+ getattr(s, 'logical_type', None) == constants.DECIMAL)),
+ 'int': lambda s, d: ((isinstance(d, (int, long))) and (INT_MIN_VALUE <= d <= INT_MAX_VALUE) or
+ (isinstance(d, datetime.date) and
+ getattr(s, 'logical_type', None) == constants.DATE) or
+ (isinstance(d, datetime.time) and
+ getattr(s, 'logical_type', None) == constants.TIME_MILLIS)),
+ 'long': lambda s, d: ((isinstance(d, (int, long))) and (LONG_MIN_VALUE <= d <= LONG_MAX_VALUE) or
+ (isinstance(d, datetime.time) and
+ getattr(s, 'logical_type', None) == constants.TIME_MICROS) or
+ (isinstance(d, datetime.date) and
+ _is_timezone_aware_datetime(d) and
+ getattr(s, 'logical_type', None) in (constants.TIMESTAMP_MILLIS,
+ constants.TIMESTAMP_MICROS))),
+ 'float': lambda s, d: isinstance(d, (int, long, float)),
+ 'fixed': lambda s, d: ((isinstance(d, bytes) and len(d) == s.size) or
+ (isinstance(d, Decimal) and
+ getattr(s, 'logical_type', None) == constants.DECIMAL)),
+ 'enum': lambda s, d: d in s.symbols,
+
+ 'array': lambda s, d: isinstance(d, list) and all(validate(s.items, item) for item in d),
+ 'map': lambda s, d: (isinstance(d, dict) and all(isinstance(key, unicode) for key in d) and
+ all(validate(s.values, value) for value in d.values())),
+ 'union': lambda s, d: any(validate(branch, d) for branch in s.schemas),
+ 'record': lambda s, d: (isinstance(d, dict) and
+ all(validate(f.type, d.get(f.name)) for f in s.fields) and
+ {f.name for f in s.fields}.issuperset(d.keys())),
}
_valid['double'] = _valid['float']
_valid['error_union'] = _valid['union']
_valid['error'] = _valid['request'] = _valid['record']
+
def validate(expected_schema, datum):
- """Determines if a python datum is an instance of a schema.
-
- Args:
- expected_schema: Schema to validate against.
- datum: Datum to validate.
- Returns:
- True if the datum is an instance of the schema.
- """
- global _DEBUG_VALIDATE_INDENT
- global _DEBUG_VALIDATE
- expected_type = expected_schema.type
- name = getattr(expected_schema, 'name', '')
- if name:
- name = ' ' + name
- if expected_type in ('array', 'map', 'union', 'record'):
- if _DEBUG_VALIDATE:
- print('{!s}{!s}{!s}: {!s} {{'.format(' ' * _DEBUG_VALIDATE_INDENT, expected_schema.type, name, type(datum).__name__), file=sys.stderr)
- _DEBUG_VALIDATE_INDENT += 2
- if datum is not None and not datum:
- print('{!s}<Empty>'.format(' ' * _DEBUG_VALIDATE_INDENT), file=sys.stderr)
- result = _valid[expected_type](expected_schema, datum)
- if _DEBUG_VALIDATE:
- _DEBUG_VALIDATE_INDENT -= 2
- print('{!s}}} -> {!s}'.format(' ' * _DEBUG_VALIDATE_INDENT, result), file=sys.stderr)
- else:
- result = _valid[expected_type](expected_schema, datum)
- if _DEBUG_VALIDATE:
- print('{!s}{!s}{!s}: {!s} -> {!s}'.format(' ' * _DEBUG_VALIDATE_INDENT, expected_schema.type, name, type(datum).__name__, result), file=sys.stderr)
- return result
+ """Determines if a python datum is an instance of a schema.
+
+ Args:
+ expected_schema: Schema to validate against.
+ datum: Datum to validate.
+ Returns:
+ True if the datum is an instance of the schema.
+ """
+ global _DEBUG_VALIDATE_INDENT
+ global _DEBUG_VALIDATE
+ expected_type = expected_schema.type
+ name = getattr(expected_schema, 'name', '')
+ if name:
+ name = ' ' + name
+ if expected_type in ('array', 'map', 'union', 'record'):
+ if _DEBUG_VALIDATE:
+ print('{!s}{!s}{!s}: {!s} {{'.format(' ' * _DEBUG_VALIDATE_INDENT, expected_schema.type, name, type(datum).__name__), file=sys.stderr)
+ _DEBUG_VALIDATE_INDENT += 2
+ if datum is not None and not datum:
+ print('{!s}<Empty>'.format(' ' * _DEBUG_VALIDATE_INDENT), file=sys.stderr)
+ result = _valid[expected_type](expected_schema, datum)
+ if _DEBUG_VALIDATE:
+ _DEBUG_VALIDATE_INDENT -= 2
+ print('{!s}}} -> {!s}'.format(' ' * _DEBUG_VALIDATE_INDENT, result), file=sys.stderr)
+ else:
+ result = _valid[expected_type](expected_schema, datum)
+ if _DEBUG_VALIDATE:
+ print('{!s}{!s}{!s}: {!s} -> {!s}'.format(' ' * _DEBUG_VALIDATE_INDENT,
+ expected_schema.type, name, type(datum).__name__, result), file=sys.stderr)
+ return result
#
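
Editorial note: before the encoder/decoder hunk below, note that the re-wrapped _valid table and validate() helper above implement the schema-conformance check exposed as avro.io.validate. A small hedged illustration, with an arbitrary example schema that is not part of the commit:

    # Hedged illustration of avro.io.validate from the hunk above; the
    # schema literal here is an example, not taken from the diff.
    import avro.io
    import avro.schema

    ARRAY_SCHEMA = avro.schema.parse('{"type": "array", "items": "string"}')
    print(avro.io.validate(ARRAY_SCHEMA, ['a', 'b']))   # True
    print(avro.io.validate(ARRAY_SCHEMA, ['a', 1]))     # False
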
@@ -188,917 +197,927 @@ def validate(expected_schema, datum):
#
class BinaryDecoder(object):
- """Read leaf values."""
- def __init__(self, reader):
- """
- reader is a Python object on which we can call read, seek, and tell.
- """
- self._reader = reader
-
- # read-only properties
- reader = property(lambda self: self._reader)
-
- def read(self, n):
- """
- Read n bytes.
- """
- return self.reader.read(n)
-
- def read_null(self):
- """
- null is written as zero bytes
- """
- return None
-
- def read_boolean(self):
- """
- a boolean is written as a single byte
- whose value is either 0 (false) or 1 (true).
- """
- return ord(self.read(1)) == 1
-
- def read_int(self):
- """
- int and long values are written using variable-length, zig-zag coding.
- """
- return self.read_long()
-
- def read_long(self):
- """
- int and long values are written using variable-length, zig-zag coding.
- """
- b = ord(self.read(1))
- n = b & 0x7F
- shift = 7
- while (b & 0x80) != 0:
- b = ord(self.read(1))
- n |= (b & 0x7F) << shift
- shift += 7
- datum = (n >> 1) ^ -(n & 1)
- return datum
-
- def read_float(self):
- """
- A float is written as 4 bytes.
- The float is converted into a 32-bit integer using a method equivalent to
- Java's floatToIntBits and then encoded in little-endian format.
- """
- return STRUCT_FLOAT.unpack(self.read(4))[0]
-
- def read_double(self):
- """
- A double is written as 8 bytes.
- The double is converted into a 64-bit integer using a method equivalent to
- Java's doubleToLongBits and then encoded in little-endian format.
- """
- return STRUCT_DOUBLE.unpack(self.read(8))[0]
-
- def read_decimal_from_bytes(self, precision, scale):
- """
- Decimal bytes are decoded as signed short, int or long depending on the
- size of bytes.
- """
- size = self.read_long()
- return self.read_decimal_from_fixed(precision, scale, size)
-
- def read_decimal_from_fixed(self, precision, scale, size):
- """
- Decimal is encoded as fixed. Fixed instances are encoded using the
- number of bytes declared in the schema.
- """
- datum = self.read(size)
- unscaled_datum = 0
- msb = struct.unpack('!b', datum[0:1])[0]
- leftmost_bit = (msb >> 7) & 1
- if leftmost_bit == 1:
- modified_first_byte = ord(datum[0:1]) ^ (1 << 7)
- datum = bytearray([modified_first_byte]) + datum[1:]
- for offset in range(size):
- unscaled_datum <<= 8
- unscaled_datum += ord(datum[offset:1+offset])
- unscaled_datum += pow(-2, (size*8) - 1)
- else:
- for offset in range(size):
- unscaled_datum <<= 8
- unscaled_datum += ord(datum[offset:1+offset])
-
- original_prec = getcontext().prec
- getcontext().prec = precision
- scaled_datum = Decimal(unscaled_datum).scaleb(-scale)
- getcontext().prec = original_prec
- return scaled_datum
-
- def read_bytes(self):
- """
- Bytes are encoded as a long followed by that many bytes of data.
- """
- return self.read(self.read_long())
-
- def read_utf8(self):
- """
- A string is encoded as a long followed by
- that many bytes of UTF-8 encoded character data.
- """
- return unicode(self.read_bytes(), "utf-8")
-
- def read_date_from_int(self):
- """
- int is decoded as python date object.
- int stores the number of days from
- the unix epoch, 1 January 1970 (ISO calendar).
- """
- days_since_epoch = self.read_int()
- return datetime.date(1970, 1, 1) + datetime.timedelta(days_since_epoch)
-
- def _build_time_object(self, value, scale_to_micro):
- value = value * scale_to_micro
- value, microseconds = value // 1000000, value % 1000000
- value, seconds = value // 60, value % 60
- value, minutes = value // 60, value % 60
- hours = value
-
- return datetime.time(
- hour=hours,
- minute=minutes,
- second=seconds,
- microsecond=microseconds
- )
-
- def read_time_millis_from_int(self):
- """
- int is decoded as python time object which represents
- the number of milliseconds after midnight, 00:00:00.000.
- """
- milliseconds = self.read_int()
- return self._build_time_object(milliseconds, 1000)
-
- def read_time_micros_from_long(self):
- """
- long is decoded as python time object which represents
- the number of microseconds after midnight, 00:00:00.000000.
- """
- microseconds = self.read_long()
- return self._build_time_object(microseconds, 1)
-
- def read_timestamp_millis_from_long(self):
- """
- long is decoded as python datetime object which represents
- the number of milliseconds from the unix epoch, 1 January 1970.
- """
- timestamp_millis = self.read_long()
- timedelta = datetime.timedelta(microseconds=timestamp_millis * 1000)
- unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
- return unix_epoch_datetime + timedelta
-
- def read_timestamp_micros_from_long(self):
- """
- long is decoded as python datetime object which represents
- the number of microseconds from the unix epoch, 1 January 1970.
- """
- timestamp_micros = self.read_long()
- timedelta = datetime.timedelta(microseconds=timestamp_micros)
- unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
- return unix_epoch_datetime + timedelta
-
- def skip_null(self):
- pass
-
- def skip_boolean(self):
- self.skip(1)
-
- def skip_int(self):
- self.skip_long()
-
- def skip_long(self):
- b = ord(self.read(1))
- while (b & 0x80) != 0:
- b = ord(self.read(1))
-
- def skip_float(self):
- self.skip(4)
-
- def skip_double(self):
- self.skip(8)
-
- def skip_bytes(self):
- self.skip(self.read_long())
+ """Read leaf values."""
+
+ def __init__(self, reader):
+ """
+ reader is a Python object on which we can call read, seek, and tell.
+ """
+ self._reader = reader
+
+ # read-only properties
+ reader = property(lambda self: self._reader)
+
+ def read(self, n):
+ """
+ Read n bytes.
+ """
+ return self.reader.read(n)
+
+ def read_null(self):
+ """
+ null is written as zero bytes
+ """
+ return None
+
+ def read_boolean(self):
+ """
+ a boolean is written as a single byte
+ whose value is either 0 (false) or 1 (true).
+ """
+ return ord(self.read(1)) == 1
+
+ def read_int(self):
+ """
+ int and long values are written using variable-length, zig-zag coding.
+ """
+ return self.read_long()
+
+ def read_long(self):
+ """
+ int and long values are written using variable-length, zig-zag coding.
+ """
+ b = ord(self.read(1))
+ n = b & 0x7F
+ shift = 7
+ while (b & 0x80) != 0:
+ b = ord(self.read(1))
+ n |= (b & 0x7F) << shift
+ shift += 7
+ datum = (n >> 1) ^ -(n & 1)
+ return datum
+
+ def read_float(self):
+ """
+ A float is written as 4 bytes.
+ The float is converted into a 32-bit integer using a method equivalent to
+ Java's floatToIntBits and then encoded in little-endian format.
+ """
+ return STRUCT_FLOAT.unpack(self.read(4))[0]
+
+ def read_double(self):
+ """
+ A double is written as 8 bytes.
+ The double is converted into a 64-bit integer using a method equivalent to
+ Java's doubleToLongBits and then encoded in little-endian format.
+ """
+ return STRUCT_DOUBLE.unpack(self.read(8))[0]
+
+ def read_decimal_from_bytes(self, precision, scale):
+ """
+ Decimal bytes are decoded as signed short, int or long depending on the
+ size of bytes.
+ """
+ size = self.read_long()
+ return self.read_decimal_from_fixed(precision, scale, size)
+
+ def read_decimal_from_fixed(self, precision, scale, size):
+ """
+ Decimal is encoded as fixed. Fixed instances are encoded using the
+ number of bytes declared in the schema.
+ """
+ datum = self.read(size)
+ unscaled_datum = 0
+ msb = struct.unpack('!b', datum[0:1])[0]
+ leftmost_bit = (msb >> 7) & 1
+ if leftmost_bit == 1:
+ modified_first_byte = ord(datum[0:1]) ^ (1 << 7)
+ datum = bytearray([modified_first_byte]) + datum[1:]
+ for offset in range(size):
+ unscaled_datum <<= 8
+ unscaled_datum += ord(datum[offset:1 + offset])
+ unscaled_datum += pow(-2, (size * 8) - 1)
+ else:
+ for offset in range(size):
+ unscaled_datum <<= 8
+ unscaled_datum += ord(datum[offset:1 + offset])
+
+ original_prec = getcontext().prec
+ getcontext().prec = precision
+ scaled_datum = Decimal(unscaled_datum).scaleb(-scale)
+ getcontext().prec = original_prec
+ return scaled_datum
+
+ def read_bytes(self):
+ """
+ Bytes are encoded as a long followed by that many bytes of data.
+ """
+ return self.read(self.read_long())
+
+ def read_utf8(self):
+ """
+ A string is encoded as a long followed by
+ that many bytes of UTF-8 encoded character data.
+ """
+ return unicode(self.read_bytes(), "utf-8")
+
+ def read_date_from_int(self):
+ """
+ int is decoded as python date object.
+ int stores the number of days from
+ the unix epoch, 1 January 1970 (ISO calendar).
+ """
+ days_since_epoch = self.read_int()
+ return datetime.date(1970, 1, 1) + datetime.timedelta(days_since_epoch)
+
+ def _build_time_object(self, value, scale_to_micro):
+ value = value * scale_to_micro
+ value, microseconds = value // 1000000, value % 1000000
+ value, seconds = value // 60, value % 60
+ value, minutes = value // 60, value % 60
+ hours = value
+
+ return datetime.time(
+ hour=hours,
+ minute=minutes,
+ second=seconds,
+ microsecond=microseconds
+ )
- def skip_utf8(self):
- self.skip_bytes()
+ def read_time_millis_from_int(self):
+ """
+ int is decoded as python time object which represents
+ the number of milliseconds after midnight, 00:00:00.000.
+ """
+ milliseconds = self.read_int()
+ return self._build_time_object(milliseconds, 1000)
+
+ def read_time_micros_from_long(self):
+ """
+ long is decoded as python time object which represents
+ the number of microseconds after midnight, 00:00:00.000000.
+ """
+ microseconds = self.read_long()
+ return self._build_time_object(microseconds, 1)
+
+ def read_timestamp_millis_from_long(self):
+ """
+ long is decoded as python datetime object which represents
+ the number of milliseconds from the unix epoch, 1 January 1970.
+ """
+ timestamp_millis = self.read_long()
+ timedelta = datetime.timedelta(microseconds=timestamp_millis * 1000)
+ unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
+ return unix_epoch_datetime + timedelta
+
+ def read_timestamp_micros_from_long(self):
+ """
+ long is decoded as python datetime object which represents
+ the number of microseconds from the unix epoch, 1 January 1970.
+ """
+ timestamp_micros = self.read_long()
+ timedelta = datetime.timedelta(microseconds=timestamp_micros)
+ unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
+ return unix_epoch_datetime + timedelta
+
+ def skip_null(self):
+ pass
+
+ def skip_boolean(self):
+ self.skip(1)
+
+ def skip_int(self):
+ self.skip_long()
+
+ def skip_long(self):
+ b = ord(self.read(1))
+ while (b & 0x80) != 0:
+ b = ord(self.read(1))
+
+ def skip_float(self):
+ self.skip(4)
+
+ def skip_double(self):
+ self.skip(8)
+
+ def skip_bytes(self):
+ self.skip(self.read_long())
+
+ def skip_utf8(self):
+ self.skip_bytes()
+
+ def skip(self, n):
+ self.reader.seek(self.reader.tell() + n)
- def skip(self, n):
- self.reader.seek(self.reader.tell() + n)
class BinaryEncoder(object):
- """Write leaf values."""
- def __init__(self, writer):
- """
- writer is a Python object on which we can call write.
- """
- self._writer = writer
-
- # read-only properties
- writer = property(lambda self: self._writer)
-
- def write(self, datum):
- """Write an arbitrary datum."""
- self.writer.write(datum)
-
- def write_null(self, datum):
- """
- null is written as zero bytes
- """
- pass
-
- def write_boolean(self, datum):
- """
- a boolean is written as a single byte
- whose value is either 0 (false) or 1 (true).
- """
- self.write(bytearray([bool(datum)]))
-
- def write_int(self, datum):
- """
- int and long values are written using variable-length, zig-zag coding.
- """
- self.write_long(datum);
-
- def write_long(self, datum):
- """
- int and long values are written using variable-length, zig-zag coding.
- """
- datum = (datum << 1) ^ (datum >> 63)
- while (datum & ~0x7F) != 0:
- self.write(bytearray([(datum & 0x7f) | 0x80]))
- datum >>= 7
- self.write(bytearray([datum]))
-
- def write_float(self, datum):
- """
- A float is written as 4 bytes.
- The float is converted into a 32-bit integer using a method equivalent to
- Java's floatToIntBits and then encoded in little-endian format.
- """
- self.write(STRUCT_FLOAT.pack(datum))
-
- def write_double(self, datum):
- """
- A double is written as 8 bytes.
- The double is converted into a 64-bit integer using a method equivalent to
- Java's doubleToLongBits and then encoded in little-endian format.
- """
- self.write(STRUCT_DOUBLE.pack(datum))
-
- def write_decimal_bytes(self, datum, scale):
- """
- Decimal in bytes are encoded as long. Since size of packed value in bytes for
- signed long is 8, 8 bytes are written.
- """
- sign, digits, exp = datum.as_tuple()
- if exp > scale:
- raise AvroTypeException('Scale provided in schema does not match the decimal')
-
- unscaled_datum = 0
- for digit in digits:
- unscaled_datum = (unscaled_datum * 10) + digit
-
- bits_req = unscaled_datum.bit_length() + 1
- if sign:
- unscaled_datum = (1 << bits_req) - unscaled_datum
-
- bytes_req = bits_req // 8
- padding_bits = ~((1 << bits_req) - 1) if sign else 0
- packed_bits = padding_bits | unscaled_datum
-
- bytes_req += 1 if (bytes_req << 3) < bits_req else 0
- self.write_long(bytes_req)
- for index in range(bytes_req-1, -1, -1):
- bits_to_write = packed_bits >> (8 * index)
- self.write(bytearray([bits_to_write & 0xff]))
-
- def write_decimal_fixed(self, datum, scale, size):
- """
- Decimal in fixed are encoded as size of fixed bytes.
- """
- sign, digits, exp = datum.as_tuple()
- if exp > scale:
- raise AvroTypeException('Scale provided in schema does not match the decimal')
-
- unscaled_datum = 0
- for digit in digits:
- unscaled_datum = (unscaled_datum * 10) + digit
-
- bits_req = unscaled_datum.bit_length() + 1
- size_in_bits = size * 8
- offset_bits = size_in_bits - bits_req
-
- mask = 2 ** size_in_bits - 1
- bit = 1
- for i in range(bits_req):
- mask ^= bit
- bit <<= 1
-
- if bits_req < 8:
- bytes_req = 1
- else:
- bytes_req = bits_req // 8
- if bits_req % 8 != 0:
- bytes_req += 1
- if sign:
- unscaled_datum = (1 << bits_req) - unscaled_datum
- unscaled_datum = mask | unscaled_datum
- for index in range(size-1, -1, -1):
- bits_to_write = unscaled_datum >> (8 * index)
- self.write(bytearray([bits_to_write & 0xff]))
- else:
- for i in range(offset_bits // 8):
- self.write(b'\x00')
- for index in range(bytes_req-1, -1, -1):
- bits_to_write = unscaled_datum >> (8 * index)
- self.write(bytearray([bits_to_write & 0xff]))
-
- def write_bytes(self, datum):
- """
- Bytes are encoded as a long followed by that many bytes of data.
- """
- self.write_long(len(datum))
- self.write(struct.pack('%ds' % len(datum), datum))
-
- def write_utf8(self, datum):
- """
- A string is encoded as a long followed by
- that many bytes of UTF-8 encoded character data.
- """
- datum = datum.encode("utf-8")
- self.write_bytes(datum)
-
- def write_date_int(self, datum):
- """
- Encode python date object as int.
- It stores the number of days from
- the unix epoch, 1 January 1970 (ISO calendar).
- """
- delta_date = datum - datetime.date(1970, 1, 1)
- self.write_int(delta_date.days)
-
- def write_time_millis_int(self, datum):
- """
- Encode python time object as int.
- It stores the number of milliseconds from midnight, 00:00:00.000
- """
- milliseconds = datum.hour*3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000
- self.write_int(milliseconds)
-
- def write_time_micros_long(self, datum):
- """
- Encode python time object as long.
- It stores the number of microseconds from midnight, 00:00:00.000000
- """
- microseconds = datum.hour*3600000000 + datum.minute * 60000000 + datum.second * 1000000 + datum.microsecond
- self.write_long(microseconds)
-
- def _timedelta_total_microseconds(self, timedelta):
- return (
- timedelta.microseconds + (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6)
-
- def write_timestamp_millis_long(self, datum):
- """
- Encode python datetime object as long.
- It stores the number of milliseconds from midnight of unix epoch, 1 January 1970.
- """
- datum = datum.astimezone(tz=timezones.utc)
- timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
- milliseconds = self._timedelta_total_microseconds(timedelta) / 1000
- self.write_long(long(milliseconds))
-
- def write_timestamp_micros_long(self, datum):
- """
- Encode python datetime object as long.
- It stores the number of microseconds from midnight of unix epoch, 1 January 1970.
- """
- datum = datum.astimezone(tz=timezones.utc)
- timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
- microseconds = self._timedelta_total_microseconds(timedelta)
- self.write_long(long(microseconds))
+ """Write leaf values."""
+
+ def __init__(self, writer):
+ """
+ writer is a Python object on which we can call write.
+ """
+ self._writer = writer
+
+ # read-only properties
+ writer = property(lambda self: self._writer)
+
+ def write(self, datum):
+ """Write an arbitrary datum."""
+ self.writer.write(datum)
+
+ def write_null(self, datum):
+ """
+ null is written as zero bytes
+ """
+ pass
+
+ def write_boolean(self, datum):
+ """
+ a boolean is written as a single byte
+ whose value is either 0 (false) or 1 (true).
+ """
+ self.write(bytearray([bool(datum)]))
+
+ def write_int(self, datum):
+ """
+ int and long values are written using variable-length, zig-zag coding.
+ """
+ self.write_long(datum)
+
+ def write_long(self, datum):
+ """
+ int and long values are written using variable-length, zig-zag coding.
+ """
+ datum = (datum << 1) ^ (datum >> 63)
+ while (datum & ~0x7F) != 0:
+ self.write(bytearray([(datum & 0x7f) | 0x80]))
+ datum >>= 7
+ self.write(bytearray([datum]))
+
+ def write_float(self, datum):
+ """
+ A float is written as 4 bytes.
+ The float is converted into a 32-bit integer using a method equivalent to
+ Java's floatToIntBits and then encoded in little-endian format.
+ """
+ self.write(STRUCT_FLOAT.pack(datum))
+
+ def write_double(self, datum):
+ """
+ A double is written as 8 bytes.
+ The double is converted into a 64-bit integer using a method equivalent to
+ Java's doubleToLongBits and then encoded in little-endian format.
+ """
+ self.write(STRUCT_DOUBLE.pack(datum))
+
+ def write_decimal_bytes(self, datum, scale):
+ """
+ Decimal in bytes are encoded as long. Since size of packed value in bytes for
+ signed long is 8, 8 bytes are written.
+ """
+ sign, digits, exp = datum.as_tuple()
+ if exp > scale:
+ raise AvroTypeException('Scale provided in schema does not match the decimal')
+
+ unscaled_datum = 0
+ for digit in digits:
+ unscaled_datum = (unscaled_datum * 10) + digit
+
+ bits_req = unscaled_datum.bit_length() + 1
+ if sign:
+ unscaled_datum = (1 << bits_req) - unscaled_datum
+
+ bytes_req = bits_req // 8
+ padding_bits = ~((1 << bits_req) - 1) if sign else 0
+ packed_bits = padding_bits | unscaled_datum
+
+ bytes_req += 1 if (bytes_req << 3) < bits_req else 0
+ self.write_long(bytes_req)
+ for index in range(bytes_req - 1, -1, -1):
+ bits_to_write = packed_bits >> (8 * index)
+ self.write(bytearray([bits_to_write & 0xff]))
+
+ def write_decimal_fixed(self, datum, scale, size):
+ """
+ Decimal in fixed are encoded as size of fixed bytes.
+ """
+ sign, digits, exp = datum.as_tuple()
+ if exp > scale:
+ raise AvroTypeException('Scale provided in schema does not match the decimal')
+
+ unscaled_datum = 0
+ for digit in digits:
+ unscaled_datum = (unscaled_datum * 10) + digit
+
+ bits_req = unscaled_datum.bit_length() + 1
+ size_in_bits = size * 8
+ offset_bits = size_in_bits - bits_req
+
+ mask = 2 ** size_in_bits - 1
+ bit = 1
+ for i in range(bits_req):
+ mask ^= bit
+ bit <<= 1
+
+ if bits_req < 8:
+ bytes_req = 1
+ else:
+ bytes_req = bits_req // 8
+ if bits_req % 8 != 0:
+ bytes_req += 1
+ if sign:
+ unscaled_datum = (1 << bits_req) - unscaled_datum
+ unscaled_datum = mask | unscaled_datum
+ for index in range(size - 1, -1, -1):
+ bits_to_write = unscaled_datum >> (8 * index)
+ self.write(bytearray([bits_to_write & 0xff]))
+ else:
+ for i in range(offset_bits // 8):
+ self.write(b'\x00')
+ for index in range(bytes_req - 1, -1, -1):
+ bits_to_write = unscaled_datum >> (8 * index)
+ self.write(bytearray([bits_to_write & 0xff]))
+
+ def write_bytes(self, datum):
+ """
+ Bytes are encoded as a long followed by that many bytes of data.
+ """
+ self.write_long(len(datum))
+ self.write(struct.pack('%ds' % len(datum), datum))
+
+ def write_utf8(self, datum):
+ """
+ A string is encoded as a long followed by
+ that many bytes of UTF-8 encoded character data.
+ """
+ datum = datum.encode("utf-8")
+ self.write_bytes(datum)
+
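A rough illustration of the framing above: the UTF-8 string "foo" is three bytes, so it is written as its zig-zag-encoded length (3 -> 0x06) followed by the raw bytes (sketch only; the shift trick below only covers short, non-negative lengths):

    payload = "foo".encode("utf-8")
    encoded = bytes([len(payload) << 1]) + payload  # zig-zag of a small length fits one byte
    assert encoded == b'\x06foo'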
+ def write_date_int(self, datum):
+ """
+ Encode python date object as int.
+ It stores the number of days from
+ the unix epoch, 1 January 1970 (ISO calendar).
+ """
+ delta_date = datum - datetime.date(1970, 1, 1)
+ self.write_int(delta_date.days)
+
+ def write_time_millis_int(self, datum):
+ """
+ Encode python time object as int.
+ It stores the number of milliseconds from midnight, 00:00:00.000
+ """
+ milliseconds = datum.hour * 3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000
+ self.write_int(milliseconds)
+
+ def write_time_micros_long(self, datum):
+ """
+ Encode python time object as long.
+ It stores the number of microseconds from midnight, 00:00:00.000000
+ """
+ microseconds = datum.hour * 3600000000 + datum.minute * 60000000 + datum.second * 1000000 + datum.microsecond
+ self.write_long(microseconds)
+
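The arithmetic behind the date and time logical types above, spelled out for one concrete value (illustrative only):

    import datetime

    d = datetime.date(2020, 5, 1)
    days = (d - datetime.date(1970, 1, 1)).days             # 'date' -> int
    t = datetime.time(13, 1, 35, 250000)
    millis = (t.hour * 3600000 + t.minute * 60000
              + t.second * 1000 + t.microsecond // 1000)    # 'time-millis' -> int
    micros = (t.hour * 3600000000 + t.minute * 60000000
              + t.second * 1000000 + t.microsecond)         # 'time-micros' -> long
    assert days == 18383
    assert millis == 46895250
    assert micros == 46895250000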
+ def _timedelta_total_microseconds(self, timedelta):
+ return (
+ timedelta.microseconds + (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6)
+
+ def write_timestamp_millis_long(self, datum):
+ """
+ Encode python datetime object as long.
+ It stores the number of milliseconds from midnight of unix epoch, 1 January 1970.
+ """
+ datum = datum.astimezone(tz=timezones.utc)
+ timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
+ milliseconds = self._timedelta_total_microseconds(timedelta) / 1000
+ self.write_long(long(milliseconds))
+
+ def write_timestamp_micros_long(self, datum):
+ """
+ Encode python datetime object as long.
+ It stores the number of microseconds from midnight of unix epoch, 1 January 1970.
+ """
+ datum = datum.astimezone(tz=timezones.utc)
+ timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
+ microseconds = self._timedelta_total_microseconds(timedelta)
+ self.write_long(long(microseconds))
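A quick sketch of the timestamp conversion above for an aware datetime; datetime.timezone.utc stands in for avro's timezones.utc here (illustrative only):

    import datetime

    epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
    dt = datetime.datetime(2020, 5, 1, 13, 1, 35, tzinfo=datetime.timezone.utc)
    delta = dt - epoch
    micros = delta.microseconds + (delta.seconds + delta.days * 24 * 3600) * 10 ** 6
    assert micros == (18383 * 86400 + 13 * 3600 + 60 + 35) * 10 ** 6        # timestamp-micros
    assert micros // 1000 == (18383 * 86400 + 13 * 3600 + 60 + 35) * 1000   # timestamp-millis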
#
# DatumReader/Writer
#
class DatumReader(object):
- """Deserialize Avro-encoded data into a Python data structure."""
- def __init__(self, writers_schema=None, readers_schema=None):
- """
- As defined in the Avro specification, we call the schema encoded
- in the data the "writer's schema", and the schema expected by the
- reader the "reader's schema".
- """
- self._writers_schema = writers_schema
- self._readers_schema = readers_schema
-
- # read/write properties
- def set_writers_schema(self, writers_schema):
- self._writers_schema = writers_schema
- writers_schema = property(lambda self: self._writers_schema,
- set_writers_schema)
- def set_readers_schema(self, readers_schema):
- self._readers_schema = readers_schema
- readers_schema = property(lambda self: self._readers_schema,
- set_readers_schema)
- def read(self, decoder):
- if self.readers_schema is None:
- self.readers_schema = self.writers_schema
- return self.read_data(self.writers_schema, self.readers_schema, decoder)
-
- def read_data(self, writers_schema, readers_schema, decoder):
- # schema matching
- if not readers_schema.match(writers_schema):
- fail_msg = 'Schemas do not match.'
- raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
-
- logical_type = getattr(writers_schema, 'logical_type', None)
-
- # function dispatch for reading data based on type of writer's schema
- if writers_schema.type in ['union', 'error_union']:
- return self.read_union(writers_schema, readers_schema, decoder)
-
- if readers_schema.type in ['union', 'error_union']:
- # schema resolution: reader's schema is a union, writer's schema is not
- for s in readers_schema.schemas:
- if s.match(writers_schema):
- return self.read_data(writers_schema, s, decoder)
-
- # This shouldn't happen because of the match check at the start of this method.
- fail_msg = 'Schemas do not match.'
- raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
-
- if writers_schema.type == 'null':
- return decoder.read_null()
- elif writers_schema.type == 'boolean':
- return decoder.read_boolean()
- elif writers_schema.type == 'string':
- return decoder.read_utf8()
- elif writers_schema.type == 'int':
- if logical_type == constants.DATE:
- return decoder.read_date_from_int()
- if logical_type == constants.TIME_MILLIS:
- return decoder.read_time_millis_from_int()
- return decoder.read_int()
- elif writers_schema.type == 'long':
- if logical_type == constants.TIME_MICROS:
- return decoder.read_time_micros_from_long()
- elif logical_type == constants.TIMESTAMP_MILLIS:
- return decoder.read_timestamp_millis_from_long()
- elif logical_type == constants.TIMESTAMP_MICROS:
- return decoder.read_timestamp_micros_from_long()
- else:
- return decoder.read_long()
- elif writers_schema.type == 'float':
- return decoder.read_float()
- elif writers_schema.type == 'double':
- return decoder.read_double()
- elif writers_schema.type == 'bytes':
- if logical_type == 'decimal':
- return decoder.read_decimal_from_bytes(
- writers_schema.get_prop('precision'),
- writers_schema.get_prop('scale')
- )
- else:
- return decoder.read_bytes()
- elif writers_schema.type == 'fixed':
- if logical_type == 'decimal':
- return decoder.read_decimal_from_fixed(
- writers_schema.get_prop('precision'),
- writers_schema.get_prop('scale'),
- writers_schema.size
- )
- return self.read_fixed(writers_schema, readers_schema, decoder)
- elif writers_schema.type == 'enum':
- return self.read_enum(writers_schema, readers_schema, decoder)
- elif writers_schema.type == 'array':
- return self.read_array(writers_schema, readers_schema, decoder)
- elif writers_schema.type == 'map':
- return self.read_map(writers_schema, readers_schema, decoder)
- elif writers_schema.type in ['record', 'error', 'request']:
- return self.read_record(writers_schema, readers_schema, decoder)
- else:
- fail_msg = "Cannot read unknown schema type: %s" % writers_schema.type
- raise schema.AvroException(fail_msg)
-
- def skip_data(self, writers_schema, decoder):
- if writers_schema.type == 'null':
- return decoder.skip_null()
- elif writers_schema.type == 'boolean':
- return decoder.skip_boolean()
- elif writers_schema.type == 'string':
- return decoder.skip_utf8()
- elif writers_schema.type == 'int':
- return decoder.skip_int()
- elif writers_schema.type == 'long':
- return decoder.skip_long()
- elif writers_schema.type == 'float':
- return decoder.skip_float()
- elif writers_schema.type == 'double':
- return decoder.skip_double()
- elif writers_schema.type == 'bytes':
- return decoder.skip_bytes()
- elif writers_schema.type == 'fixed':
- return self.skip_fixed(writers_schema, decoder)
- elif writers_schema.type == 'enum':
- return self.skip_enum(writers_schema, decoder)
- elif writers_schema.type == 'array':
- return self.skip_array(writers_schema, decoder)
- elif writers_schema.type == 'map':
- return self.skip_map(writers_schema, decoder)
- elif writers_schema.type in ['union', 'error_union']:
- return self.skip_union(writers_schema, decoder)
- elif writers_schema.type in ['record', 'error', 'request']:
- return self.skip_record(writers_schema, decoder)
- else:
- fail_msg = "Unknown schema type: %s" % writers_schema.type
- raise schema.AvroException(fail_msg)
-
- def read_fixed(self, writers_schema, readers_schema, decoder):
- """
- Fixed instances are encoded using the number of bytes declared
- in the schema.
- """
- return decoder.read(writers_schema.size)
-
- def skip_fixed(self, writers_schema, decoder):
- return decoder.skip(writers_schema.size)
-
- def read_enum(self, writers_schema, readers_schema, decoder):
- """
- An enum is encoded by a int, representing the zero-based position
- of the symbol in the schema.
- """
- # read data
- index_of_symbol = decoder.read_int()
- if index_of_symbol >= len(writers_schema.symbols):
- fail_msg = "Can't access enum index %d for enum with %d symbols"\
- % (index_of_symbol, len(writers_schema.symbols))
- raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
- read_symbol = writers_schema.symbols[index_of_symbol]
-
- # schema resolution
- if read_symbol not in readers_schema.symbols:
- fail_msg = "Symbol %s not present in Reader's Schema" % read_symbol
- raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
-
- return read_symbol
+ """Deserialize Avro-encoded data into a Python data structure."""
+
+ def __init__(self, writers_schema=None, readers_schema=None):
+ """
+ As defined in the Avro specification, we call the schema encoded
+ in the data the "writer's schema", and the schema expected by the
+ reader the "reader's schema".
+ """
+ self._writers_schema = writers_schema
+ self._readers_schema = readers_schema
+
+ # read/write properties
+ def set_writers_schema(self, writers_schema):
+ self._writers_schema = writers_schema
+ writers_schema = property(lambda self: self._writers_schema,
+ set_writers_schema)
+
+ def set_readers_schema(self, readers_schema):
+ self._readers_schema = readers_schema
+ readers_schema = property(lambda self: self._readers_schema,
+ set_readers_schema)
+
+ def read(self, decoder):
+ if self.readers_schema is None:
+ self.readers_schema = self.writers_schema
+ return self.read_data(self.writers_schema, self.readers_schema, decoder)
+
+ def read_data(self, writers_schema, readers_schema, decoder):
+ # schema matching
+ if not readers_schema.match(writers_schema):
+ fail_msg = 'Schemas do not match.'
+ raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
+
+ logical_type = getattr(writers_schema, 'logical_type', None)
+
+ # function dispatch for reading data based on type of writer's schema
+ if writers_schema.type in ['union', 'error_union']:
+ return self.read_union(writers_schema, readers_schema, decoder)
+
+ if readers_schema.type in ['union', 'error_union']:
+ # schema resolution: reader's schema is a union, writer's schema is not
+ for s in readers_schema.schemas:
+ if s.match(writers_schema):
+ return self.read_data(writers_schema, s, decoder)
+
+ # This shouldn't happen because of the match check at the start of this method.
+ fail_msg = 'Schemas do not match.'
+ raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
+
+ if writers_schema.type == 'null':
+ return decoder.read_null()
+ elif writers_schema.type == 'boolean':
+ return decoder.read_boolean()
+ elif writers_schema.type == 'string':
+ return decoder.read_utf8()
+ elif writers_schema.type == 'int':
+ if logical_type == constants.DATE:
+ return decoder.read_date_from_int()
+ if logical_type == constants.TIME_MILLIS:
+ return decoder.read_time_millis_from_int()
+ return decoder.read_int()
+ elif writers_schema.type == 'long':
+ if logical_type == constants.TIME_MICROS:
+ return decoder.read_time_micros_from_long()
+ elif logical_type == constants.TIMESTAMP_MILLIS:
+ return decoder.read_timestamp_millis_from_long()
+ elif logical_type == constants.TIMESTAMP_MICROS:
+ return decoder.read_timestamp_micros_from_long()
+ else:
+ return decoder.read_long()
+ elif writers_schema.type == 'float':
+ return decoder.read_float()
+ elif writers_schema.type == 'double':
+ return decoder.read_double()
+ elif writers_schema.type == 'bytes':
+ if logical_type == 'decimal':
+ return decoder.read_decimal_from_bytes(
+ writers_schema.get_prop('precision'),
+ writers_schema.get_prop('scale')
+ )
+ else:
+ return decoder.read_bytes()
+ elif writers_schema.type == 'fixed':
+ if logical_type == 'decimal':
+ return decoder.read_decimal_from_fixed(
+ writers_schema.get_prop('precision'),
+ writers_schema.get_prop('scale'),
+ writers_schema.size
+ )
+ return self.read_fixed(writers_schema, readers_schema, decoder)
+ elif writers_schema.type == 'enum':
+ return self.read_enum(writers_schema, readers_schema, decoder)
+ elif writers_schema.type == 'array':
+ return self.read_array(writers_schema, readers_schema, decoder)
+ elif writers_schema.type == 'map':
+ return self.read_map(writers_schema, readers_schema, decoder)
+ elif writers_schema.type in ['record', 'error', 'request']:
+ return self.read_record(writers_schema, readers_schema, decoder)
+ else:
+ fail_msg = "Cannot read unknown schema type: %s" % writers_schema.type
+ raise schema.AvroException(fail_msg)
+
+ def skip_data(self, writers_schema, decoder):
+ if writers_schema.type == 'null':
+ return decoder.skip_null()
+ elif writers_schema.type == 'boolean':
+ return decoder.skip_boolean()
+ elif writers_schema.type == 'string':
+ return decoder.skip_utf8()
+ elif writers_schema.type == 'int':
+ return decoder.skip_int()
+ elif writers_schema.type == 'long':
+ return decoder.skip_long()
+ elif writers_schema.type == 'float':
+ return decoder.skip_float()
+ elif writers_schema.type == 'double':
+ return decoder.skip_double()
+ elif writers_schema.type == 'bytes':
+ return decoder.skip_bytes()
+ elif writers_schema.type == 'fixed':
+ return self.skip_fixed(writers_schema, decoder)
+ elif writers_schema.type == 'enum':
+ return self.skip_enum(writers_schema, decoder)
+ elif writers_schema.type == 'array':
+ return self.skip_array(writers_schema, decoder)
+ elif writers_schema.type == 'map':
+ return self.skip_map(writers_schema, decoder)
+ elif writers_schema.type in ['union', 'error_union']:
+ return self.skip_union(writers_schema, decoder)
+ elif writers_schema.type in ['record', 'error', 'request']:
+ return self.skip_record(writers_schema, decoder)
+ else:
+ fail_msg = "Unknown schema type: %s" % writers_schema.type
+ raise schema.AvroException(fail_msg)
+
+ def read_fixed(self, writers_schema, readers_schema, decoder):
+ """
+ Fixed instances are encoded using the number of bytes declared
+ in the schema.
+ """
+ return decoder.read(writers_schema.size)
+
+ def skip_fixed(self, writers_schema, decoder):
+ return decoder.skip(writers_schema.size)
+
+ def read_enum(self, writers_schema, readers_schema, decoder):
+ """
+ An enum is encoded by a int, representing the zero-based position
+ of the symbol in the schema.
+ """
+ # read data
+ index_of_symbol = decoder.read_int()
+ if index_of_symbol >= len(writers_schema.symbols):
+ fail_msg = "Can't access enum index %d for enum with %d symbols"\
+ % (index_of_symbol, len(writers_schema.symbols))
+ raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
+ read_symbol = writers_schema.symbols[index_of_symbol]
+
+ # schema resolution
+ if read_symbol not in readers_schema.symbols:
+ fail_msg = "Symbol %s not present in Reader's Schema" % read_symbol
+ raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
+
+ return read_symbol
+
+ def skip_enum(self, writers_schema, decoder):
+ return decoder.skip_int()
+
+ def read_array(self, writers_schema, readers_schema, decoder):
+ """
+ Arrays are encoded as a series of blocks.
+
+ Each block consists of a long count value,
+ followed by that many array items.
+ A block with count zero indicates the end of the array.
+ Each item is encoded per the array's item schema.
+
+ If a block's count is negative,
+ then the count is followed immediately by a long block size,
+ indicating the number of bytes in the block.
+ The actual count in this case
+ is the absolute value of the count written.
+ """
+ read_items = []
+ block_count = decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_count = -block_count
+ block_size = decoder.read_long()
+ for i in range(block_count):
+ read_items.append(self.read_data(writers_schema.items,
+ readers_schema.items, decoder))
+ block_count = decoder.read_long()
+ return read_items
+
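A worked example of the block framing described above, assuming the public avro.io and avro.schema.parse API of the unified Python implementation; the exact bytes follow from the count/item/terminator rule:

    import io
    import avro.io
    import avro.schema

    sch = avro.schema.parse('{"type": "array", "items": "int"}')
    buf = io.BytesIO()
    avro.io.DatumWriter(sch).write([3, 27], avro.io.BinaryEncoder(buf))
    # one block: count 2 (zig-zag 0x04), items 3 and 27 (0x06, 0x36), end marker 0x00
    assert buf.getvalue() == b'\x04\x06\x36\x00'
    buf.seek(0)
    assert avro.io.DatumReader(sch, sch).read(avro.io.BinaryDecoder(buf)) == [3, 27]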
+ def skip_array(self, writers_schema, decoder):
+ block_count = decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_size = decoder.read_long()
+ decoder.skip(block_size)
+ else:
+ for i in range(block_count):
+ self.skip_data(writers_schema.items, decoder)
+ block_count = decoder.read_long()
+
+ def read_map(self, writers_schema, readers_schema, decoder):
+ """
+ Maps are encoded as a series of blocks.
+
+ Each block consists of a long count value,
+ followed by that many key/value pairs.
+ A block with count zero indicates the end of the map.
+ Each item is encoded per the map's value schema.
+
+ If a block's count is negative,
+ then the count is followed immediately by a long block size,
+ indicating the number of bytes in the block.
+ The actual count in this case
+ is the absolute value of the count written.
+ """
+ read_items = {}
+ block_count = decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_count = -block_count
+ block_size = decoder.read_long()
+ for i in range(block_count):
+ key = decoder.read_utf8()
+ read_items[key] = self.read_data(writers_schema.values,
+ readers_schema.values, decoder)
+ block_count = decoder.read_long()
+ return read_items
+
+ def skip_map(self, writers_schema, decoder):
+ block_count = decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_size = decoder.read_long()
+ decoder.skip(block_size)
+ else:
+ for i in range(block_count):
+ decoder.skip_utf8()
+ self.skip_data(writers_schema.values, decoder)
+ block_count = decoder.read_long()
+
+ def read_union(self, writers_schema, readers_schema, decoder):
+ """
+ A union is encoded by first writing a long value indicating
+ the zero-based position within the union of the schema of its value.
+ The value is then encoded per the indicated schema within the union.
+ """
+ # schema resolution
+ index_of_schema = int(decoder.read_long())
+ if index_of_schema >= len(writers_schema.schemas):
+ fail_msg = "Can't access branch index %d for union with %d branches"\
+ % (index_of_schema, len(writers_schema.schemas))
+ raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
+ selected_writers_schema = writers_schema.schemas[index_of_schema]
+
+ # read data
+ return self.read_data(selected_writers_schema, readers_schema, decoder)
+
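A small sketch of the union framing above, under the same avro.schema.parse assumption: for the schema ["null", "int"], the value 5 is written as branch index 1 followed by the int itself:

    import io
    import avro.io
    import avro.schema

    sch = avro.schema.parse('["null", "int"]')
    buf = io.BytesIO()
    avro.io.DatumWriter(sch).write(5, avro.io.BinaryEncoder(buf))
    assert buf.getvalue() == b'\x02\x0a'   # zig-zag(1), then zig-zag(5)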
+ def skip_union(self, writers_schema, decoder):
+ index_of_schema = int(decoder.read_long())
+ if index_of_schema >= len(writers_schema.schemas):
+ fail_msg = "Can't access branch index %d for union with %d branches"\
+ % (index_of_schema, len(writers_schema.schemas))
+ raise SchemaResolutionException(fail_msg, writers_schema)
+ return self.skip_data(writers_schema.schemas[index_of_schema], decoder)
+
+ def read_record(self, writers_schema, readers_schema, decoder):
+ """
+ A record is encoded by encoding the values of its fields
+ in the order that they are declared. In other words, a record
+ is encoded as just the concatenation of the encodings of its fields.
+ Field values are encoded per their schema.
+
+ Schema Resolution:
+ * the ordering of fields may be different: fields are matched by name.
+ * schemas for fields with the same name in both records are resolved
+ recursively.
+ * if the writer's record contains a field with a name not present in the
+ reader's record, the writer's value for that field is ignored.
+ * if the reader's record schema has a field that contains a default value,
+ and writer's schema does not have a field with the same name, then the
+ reader should use the default value from its field.
+ * if the reader's record schema has a field with no default value, and
+ writer's schema does not have a field with the same name, then the
+ field's value is unset.
+ """
+ # schema resolution
+ readers_fields_dict = readers_schema.fields_dict
+ read_record = {}
+ for field in writers_schema.fields:
+ readers_field = readers_fields_dict.get(field.name)
+ if readers_field is not None:
+ field_val = self.read_data(field.type, readers_field.type, decoder)
+ read_record[field.name] = field_val
+ else:
+ self.skip_data(field.type, decoder)
+
+ # fill in default values
+ if len(readers_fields_dict) > len(read_record):
+ writers_fields_dict = writers_schema.fields_dict
+ for field_name, field in readers_fields_dict.items():
+ if field_name not in writers_fields_dict:
+ if field.has_default:
+ field_val = self._read_default_value(field.type, field.default)
+ read_record[field.name] = field_val
+ else:
+ fail_msg = 'No default value for field %s' % field_name
+ raise SchemaResolutionException(fail_msg, writers_schema,
+ readers_schema)
+ return read_record
+
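A compact illustration of the resolution rules above: the reader's schema adds a field with a default, which read_record fills in (a sketch assuming avro.schema.parse and the public DatumWriter/DatumReader API):

    import io
    import avro.io
    import avro.schema

    writer_s = avro.schema.parse(
        '{"type": "record", "name": "R", "fields": ['
        '{"name": "a", "type": "int"}]}')
    reader_s = avro.schema.parse(
        '{"type": "record", "name": "R", "fields": ['
        '{"name": "a", "type": "int"},'
        '{"name": "b", "type": "string", "default": "unset"}]}')
    buf = io.BytesIO()
    avro.io.DatumWriter(writer_s).write({'a': 7}, avro.io.BinaryEncoder(buf))
    buf.seek(0)
    datum = avro.io.DatumReader(writer_s, reader_s).read(avro.io.BinaryDecoder(buf))
    assert datum == {'a': 7, 'b': 'unset'}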
+ def skip_record(self, writers_schema, decoder):
+ for field in writers_schema.fields:
+ self.skip_data(field.type, decoder)
+
+ def _read_default_value(self, field_schema, default_value):
+ """
+ Basically a JSON Decoder?
+ """
+ if field_schema.type == 'null':
+ return None
+ elif field_schema.type == 'boolean':
+ return bool(default_value)
+ elif field_schema.type == 'int':
+ return int(default_value)
+ elif field_schema.type == 'long':
+ return long(default_value)
+ elif field_schema.type in ['float', 'double']:
+ return float(default_value)
+ elif field_schema.type in ['enum', 'fixed', 'string', 'bytes']:
+ return default_value
+ elif field_schema.type == 'array':
+ read_array = []
+ for json_val in default_value:
+ item_val = self._read_default_value(field_schema.items, json_val)
+ read_array.append(item_val)
+ return read_array
+ elif field_schema.type == 'map':
+ read_map = {}
+ for key, json_val in default_value.items():
+ map_val = self._read_default_value(field_schema.values, json_val)
+ read_map[key] = map_val
+ return read_map
+ elif field_schema.type in ['union', 'error_union']:
+ return self._read_default_value(field_schema.schemas[0], default_value)
+ elif field_schema.type == 'record':
+ read_record = {}
+ for field in field_schema.fields:
+ json_val = default_value.get(field.name)
+ if json_val is None:
+ json_val = field.default
+ field_val = self._read_default_value(field.type, json_val)
+ read_record[field.name] = field_val
+ return read_record
+ else:
+ fail_msg = 'Unknown type: %s' % field_schema.type
+ raise schema.AvroException(fail_msg)
- def skip_enum(self, writers_schema, decoder):
- return decoder.skip_int()
-
- def read_array(self, writers_schema, readers_schema, decoder):
- """
- Arrays are encoded as a series of blocks.
-
- Each block consists of a long count value,
- followed by that many array items.
- A block with count zero indicates the end of the array.
- Each item is encoded per the array's item schema.
-
- If a block's count is negative,
- then the count is followed immediately by a long block size,
- indicating the number of bytes in the block.
- The actual count in this case
- is the absolute value of the count written.
- """
- read_items = []
- block_count = decoder.read_long()
- while block_count != 0:
- if block_count < 0:
- block_count = -block_count
- block_size = decoder.read_long()
- for i in range(block_count):
- read_items.append(self.read_data(writers_schema.items,
- readers_schema.items, decoder))
- block_count = decoder.read_long()
- return read_items
-
- def skip_array(self, writers_schema, decoder):
- block_count = decoder.read_long()
- while block_count != 0:
- if block_count < 0:
- block_size = decoder.read_long()
- decoder.skip(block_size)
- else:
- for i in range(block_count):
- self.skip_data(writers_schema.items, decoder)
- block_count = decoder.read_long()
-
- def read_map(self, writers_schema, readers_schema, decoder):
- """
- Maps are encoded as a series of blocks.
-
- Each block consists of a long count value,
- followed by that many key/value pairs.
- A block with count zero indicates the end of the map.
- Each item is encoded per the map's value schema.
-
- If a block's count is negative,
- then the count is followed immediately by a long block size,
- indicating the number of bytes in the block.
- The actual count in this case
- is the absolute value of the count written.
- """
- read_items = {}
- block_count = decoder.read_long()
- while block_count != 0:
- if block_count < 0:
- block_count = -block_count
- block_size = decoder.read_long()
- for i in range(block_count):
- key = decoder.read_utf8()
- read_items[key] = self.read_data(writers_schema.values,
- readers_schema.values, decoder)
- block_count = decoder.read_long()
- return read_items
-
- def skip_map(self, writers_schema, decoder):
- block_count = decoder.read_long()
- while block_count != 0:
- if block_count < 0:
- block_size = decoder.read_long()
- decoder.skip(block_size)
- else:
- for i in range(block_count):
- decoder.skip_utf8()
- self.skip_data(writers_schema.values, decoder)
- block_count = decoder.read_long()
-
- def read_union(self, writers_schema, readers_schema, decoder):
- """
- A union is encoded by first writing a long value indicating
- the zero-based position within the union of the schema of its value.
- The value is then encoded per the indicated schema within the union.
- """
- # schema resolution
- index_of_schema = int(decoder.read_long())
- if index_of_schema >= len(writers_schema.schemas):
- fail_msg = "Can't access branch index %d for union with %d branches"\
- % (index_of_schema, len(writers_schema.schemas))
- raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
- selected_writers_schema = writers_schema.schemas[index_of_schema]
-
- # read data
- return self.read_data(selected_writers_schema, readers_schema, decoder)
-
- def skip_union(self, writers_schema, decoder):
- index_of_schema = int(decoder.read_long())
- if index_of_schema >= len(writers_schema.schemas):
- fail_msg = "Can't access branch index %d for union with %d branches"\
- % (index_of_schema, len(writers_schema.schemas))
- raise SchemaResolutionException(fail_msg, writers_schema)
- return self.skip_data(writers_schema.schemas[index_of_schema], decoder)
-
- def read_record(self, writers_schema, readers_schema, decoder):
- """
- A record is encoded by encoding the values of its fields
- in the order that they are declared. In other words, a record
- is encoded as just the concatenation of the encodings of its fields.
- Field values are encoded per their schema.
-
- Schema Resolution:
- * the ordering of fields may be different: fields are matched by name.
- * schemas for fields with the same name in both records are resolved
- recursively.
- * if the writer's record contains a field with a name not present in the
- reader's record, the writer's value for that field is ignored.
- * if the reader's record schema has a field that contains a default value,
- and writer's schema does not have a field with the same name, then the
- reader should use the default value from its field.
- * if the reader's record schema has a field with no default value, and
- writer's schema does not have a field with the same name, then the
- field's value is unset.
- """
- # schema resolution
- readers_fields_dict = readers_schema.fields_dict
- read_record = {}
- for field in writers_schema.fields:
- readers_field = readers_fields_dict.get(field.name)
- if readers_field is not None:
- field_val = self.read_data(field.type, readers_field.type, decoder)
- read_record[field.name] = field_val
- else:
- self.skip_data(field.type, decoder)
-
- # fill in default values
- if len(readers_fields_dict) > len(read_record):
- writers_fields_dict = writers_schema.fields_dict
- for field_name, field in readers_fields_dict.items():
- if field_name not in writers_fields_dict:
- if field.has_default:
- field_val = self._read_default_value(field.type, field.default)
- read_record[field.name] = field_val
- else:
- fail_msg = 'No default value for field %s' % field_name
- raise SchemaResolutionException(fail_msg, writers_schema,
- readers_schema)
- return read_record
-
- def skip_record(self, writers_schema, decoder):
- for field in writers_schema.fields:
- self.skip_data(field.type, decoder)
-
- def _read_default_value(self, field_schema, default_value):
- """
- Basically a JSON Decoder?
- """
- if field_schema.type == 'null':
- return None
- elif field_schema.type == 'boolean':
- return bool(default_value)
- elif field_schema.type == 'int':
- return int(default_value)
- elif field_schema.type == 'long':
- return long(default_value)
- elif field_schema.type in ['float', 'double']:
- return float(default_value)
- elif field_schema.type in ['enum', 'fixed', 'string', 'bytes']:
- return default_value
- elif field_schema.type == 'array':
- read_array = []
- for json_val in default_value:
- item_val = self._read_default_value(field_schema.items, json_val)
- read_array.append(item_val)
- return read_array
- elif field_schema.type == 'map':
- read_map = {}
- for key, json_val in default_value.items():
- map_val = self._read_default_value(field_schema.values, json_val)
- read_map[key] = map_val
- return read_map
- elif field_schema.type in ['union', 'error_union']:
- return self._read_default_value(field_schema.schemas[0], default_value)
- elif field_schema.type == 'record':
- read_record = {}
- for field in field_schema.fields:
- json_val = default_value.get(field.name)
- if json_val is None: json_val = field.default
- field_val = self._read_default_value(field.type, json_val)
- read_record[field.name] = field_val
- return read_record
- else:
- fail_msg = 'Unknown type: %s' % field_schema.type
- raise schema.AvroException(fail_msg)
class DatumWriter(object):
- """DatumWriter for generic python objects."""
- def __init__(self, writers_schema=None):
- self._writers_schema = writers_schema
-
- # read/write properties
- def set_writers_schema(self, writers_schema):
- self._writers_schema = writers_schema
- writers_schema = property(lambda self: self._writers_schema,
- set_writers_schema)
-
- def write(self, datum, encoder):
- if not validate(self.writers_schema, datum):
- raise AvroTypeException(self.writers_schema, datum)
- self.write_data(self.writers_schema, datum, encoder)
-
- def write_data(self, writers_schema, datum, encoder):
- # function dispatch to write datum
- logical_type = getattr(writers_schema, 'logical_type', None)
- if writers_schema.type == 'null':
- encoder.write_null(datum)
- elif writers_schema.type == 'boolean':
- encoder.write_boolean(datum)
- elif writers_schema.type == 'string':
- encoder.write_utf8(datum)
- elif writers_schema.type == 'int':
- if logical_type == constants.DATE:
- encoder.write_date_int(datum)
- elif logical_type == constants.TIME_MILLIS:
- encoder.write_time_millis_int(datum)
- else:
- encoder.write_int(datum)
- elif writers_schema.type == 'long':
- if logical_type == constants.TIME_MICROS:
- encoder.write_time_micros_long(datum)
- elif logical_type == constants.TIMESTAMP_MILLIS:
- encoder.write_timestamp_millis_long(datum)
- elif logical_type == constants.TIMESTAMP_MICROS:
- encoder.write_timestamp_micros_long(datum)
- else:
- encoder.write_long(datum)
- elif writers_schema.type == 'float':
- encoder.write_float(datum)
- elif writers_schema.type == 'double':
- encoder.write_double(datum)
- elif writers_schema.type == 'bytes':
- if logical_type == 'decimal':
- encoder.write_decimal_bytes(datum, writers_schema.get_prop('scale'))
- else:
- encoder.write_bytes(datum)
- elif writers_schema.type == 'fixed':
- if logical_type == 'decimal':
- encoder.write_decimal_fixed(
- datum,
- writers_schema.get_prop('scale'),
- writers_schema.get_prop('size')
- )
- else:
- self.write_fixed(writers_schema, datum, encoder)
- elif writers_schema.type == 'enum':
- self.write_enum(writers_schema, datum, encoder)
- elif writers_schema.type == 'array':
- self.write_array(writers_schema, datum, encoder)
- elif writers_schema.type == 'map':
- self.write_map(writers_schema, datum, encoder)
- elif writers_schema.type in ['union', 'error_union']:
- self.write_union(writers_schema, datum, encoder)
- elif writers_schema.type in ['record', 'error', 'request']:
- self.write_record(writers_schema, datum, encoder)
- else:
- fail_msg = 'Unknown type: %s' % writers_schema.type
- raise schema.AvroException(fail_msg)
-
- def write_fixed(self, writers_schema, datum, encoder):
- """
- Fixed instances are encoded using the number of bytes declared
- in the schema.
- """
- encoder.write(datum)
-
- def write_enum(self, writers_schema, datum, encoder):
- """
- An enum is encoded by a int, representing the zero-based position
- of the symbol in the schema.
- """
- index_of_datum = writers_schema.symbols.index(datum)
- encoder.write_int(index_of_datum)
-
- def write_array(self, writers_schema, datum, encoder):
- """
- Arrays are encoded as a series of blocks.
-
- Each block consists of a long count value,
- followed by that many array items.
- A block with count zero indicates the end of the array.
- Each item is encoded per the array's item schema.
-
- If a block's count is negative,
- then the count is followed immediately by a long block size,
- indicating the number of bytes in the block.
- The actual count in this case
- is the absolute value of the count written.
- """
- if len(datum) > 0:
- encoder.write_long(len(datum))
- for item in datum:
- self.write_data(writers_schema.items, item, encoder)
- encoder.write_long(0)
-
- def write_map(self, writers_schema, datum, encoder):
- """
- Maps are encoded as a series of blocks.
-
- Each block consists of a long count value,
- followed by that many key/value pairs.
- A block with count zero indicates the end of the map.
- Each item is encoded per the map's value schema.
-
- If a block's count is negative,
- then the count is followed immediately by a long block size,
- indicating the number of bytes in the block.
- The actual count in this case
- is the absolute value of the count written.
- """
- if len(datum) > 0:
- encoder.write_long(len(datum))
- for key, val in datum.items():
- encoder.write_utf8(key)
- self.write_data(writers_schema.values, val, encoder)
- encoder.write_long(0)
-
- def write_union(self, writers_schema, datum, encoder):
- """
- A union is encoded by first writing a long value indicating
- the zero-based position within the union of the schema of its value.
- The value is then encoded per the indicated schema within the union.
- """
- # resolve union
- index_of_schema = -1
- for i, candidate_schema in enumerate(writers_schema.schemas):
- if validate(candidate_schema, datum):
- index_of_schema = i
- if index_of_schema < 0: raise AvroTypeException(writers_schema, datum)
-
- # write data
- encoder.write_long(index_of_schema)
- self.write_data(writers_schema.schemas[index_of_schema], datum, encoder)
-
- def write_record(self, writers_schema, datum, encoder):
- """
- A record is encoded by encoding the values of its fields
- in the order that they are declared. In other words, a record
- is encoded as just the concatenation of the encodings of its fields.
- Field values are encoded per their schema.
- """
- for field in writers_schema.fields:
- self.write_data(field.type, datum.get(field.name), encoder)
+ """DatumWriter for generic python objects."""
+
+ def __init__(self, writers_schema=None):
+ self._writers_schema = writers_schema
+
+ # read/write properties
+ def set_writers_schema(self, writers_schema):
+ self._writers_schema = writers_schema
+ writers_schema = property(lambda self: self._writers_schema,
+ set_writers_schema)
+
+ def write(self, datum, encoder):
+ if not validate(self.writers_schema, datum):
+ raise AvroTypeException(self.writers_schema, datum)
+ self.write_data(self.writers_schema, datum, encoder)
+
+ def write_data(self, writers_schema, datum, encoder):
+ # function dispatch to write datum
+ logical_type = getattr(writers_schema, 'logical_type', None)
+ if writers_schema.type == 'null':
+ encoder.write_null(datum)
+ elif writers_schema.type == 'boolean':
+ encoder.write_boolean(datum)
+ elif writers_schema.type == 'string':
+ encoder.write_utf8(datum)
+ elif writers_schema.type == 'int':
+ if logical_type == constants.DATE:
+ encoder.write_date_int(datum)
+ elif logical_type == constants.TIME_MILLIS:
+ encoder.write_time_millis_int(datum)
+ else:
+ encoder.write_int(datum)
+ elif writers_schema.type == 'long':
+ if logical_type == constants.TIME_MICROS:
+ encoder.write_time_micros_long(datum)
+ elif logical_type == constants.TIMESTAMP_MILLIS:
+ encoder.write_timestamp_millis_long(datum)
+ elif logical_type == constants.TIMESTAMP_MICROS:
+ encoder.write_timestamp_micros_long(datum)
+ else:
+ encoder.write_long(datum)
+ elif writers_schema.type == 'float':
+ encoder.write_float(datum)
+ elif writers_schema.type == 'double':
+ encoder.write_double(datum)
+ elif writers_schema.type == 'bytes':
+ if logical_type == 'decimal':
+ encoder.write_decimal_bytes(datum, writers_schema.get_prop('scale'))
+ else:
+ encoder.write_bytes(datum)
+ elif writers_schema.type == 'fixed':
+ if logical_type == 'decimal':
+ encoder.write_decimal_fixed(
+ datum,
+ writers_schema.get_prop('scale'),
+ writers_schema.get_prop('size')
+ )
+ else:
+ self.write_fixed(writers_schema, datum, encoder)
+ elif writers_schema.type == 'enum':
+ self.write_enum(writers_schema, datum, encoder)
+ elif writers_schema.type == 'array':
+ self.write_array(writers_schema, datum, encoder)
+ elif writers_schema.type == 'map':
+ self.write_map(writers_schema, datum, encoder)
+ elif writers_schema.type in ['union', 'error_union']:
+ self.write_union(writers_schema, datum, encoder)
+ elif writers_schema.type in ['record', 'error', 'request']:
+ self.write_record(writers_schema, datum, encoder)
+ else:
+ fail_msg = 'Unknown type: %s' % writers_schema.type
+ raise schema.AvroException(fail_msg)
+
+ def write_fixed(self, writers_schema, datum, encoder):
+ """
+ Fixed instances are encoded using the number of bytes declared
+ in the schema.
+ """
+ encoder.write(datum)
+
+ def write_enum(self, writers_schema, datum, encoder):
+ """
+ An enum is encoded by a int, representing the zero-based position
+ of the symbol in the schema.
+ """
+ index_of_datum = writers_schema.symbols.index(datum)
+ encoder.write_int(index_of_datum)
+
+ def write_array(self, writers_schema, datum, encoder):
+ """
+ Arrays are encoded as a series of blocks.
+
+ Each block consists of a long count value,
+ followed by that many array items.
+ A block with count zero indicates the end of the array.
+ Each item is encoded per the array's item schema.
+
+ If a block's count is negative,
+ then the count is followed immediately by a long block size,
+ indicating the number of bytes in the block.
+ The actual count in this case
+ is the absolute value of the count written.
+ """
+ if len(datum) > 0:
+ encoder.write_long(len(datum))
+ for item in datum:
+ self.write_data(writers_schema.items, item, encoder)
+ encoder.write_long(0)
+
+ def write_map(self, writers_schema, datum, encoder):
+ """
+ Maps are encoded as a series of blocks.
+
+ Each block consists of a long count value,
+ followed by that many key/value pairs.
+ A block with count zero indicates the end of the map.
+ Each item is encoded per the map's value schema.
+
+ If a block's count is negative,
+ then the count is followed immediately by a long block size,
+ indicating the number of bytes in the block.
+ The actual count in this case
+ is the absolute value of the count written.
+ """
+ if len(datum) > 0:
+ encoder.write_long(len(datum))
+ for key, val in datum.items():
+ encoder.write_utf8(key)
+ self.write_data(writers_schema.values, val, encoder)
+ encoder.write_long(0)
+
+ def write_union(self, writers_schema, datum, encoder):
+ """
+ A union is encoded by first writing a long value indicating
+ the zero-based position within the union of the schema of its value.
+ The value is then encoded per the indicated schema within the union.
+ """
+ # resolve union
+ index_of_schema = -1
+ for i, candidate_schema in enumerate(writers_schema.schemas):
+ if validate(candidate_schema, datum):
+ index_of_schema = i
+ if index_of_schema < 0:
+ raise AvroTypeException(writers_schema, datum)
+
+ # write data
+ encoder.write_long(index_of_schema)
+ self.write_data(writers_schema.schemas[index_of_schema], datum, encoder)
+
+ def write_record(self, writers_schema, datum, encoder):
+ """
+ A record is encoded by encoding the values of its fields
+ in the order that they are declared. In other words, a record
+ is encoded as just the concatenation of the encodings of its fields.
+ Field values are encoded per their schema.
+ """
+ for field in writers_schema.fields:
+ self.write_data(field.type, datum.get(field.name), encoder)
diff --git a/lang/py/avro/ipc.py b/lang/py/avro/ipc.py
index 9f2d67d..550d696 100644
--- a/lang/py/avro/ipc.py
+++ b/lang/py/avro/ipc.py
@@ -29,21 +29,22 @@ import avro.io
from avro import protocol, schema
try:
- import httplib # type: ignore
+ import httplib # type: ignore
except ImportError:
- import http.client as httplib # type: ignore
+ import http.client as httplib # type: ignore
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
def _load(name):
- dir_path = os.path.dirname(__file__)
- rsrc_path = os.path.join(dir_path, name)
- with open(rsrc_path, 'r') as f:
- return f.read()
+ dir_path = os.path.dirname(__file__)
+ rsrc_path = os.path.join(dir_path, name)
+ with open(rsrc_path, 'r') as f:
+ return f.read()
+
HANDSHAKE_REQUEST_SCHEMA_JSON = _load('HandshakeRequest.avsc')
HANDSHAKE_RESPONSE_SCHEMA_JSON = _load('HandshakeResponse.avsc')
@@ -73,413 +74,428 @@ BUFFER_SIZE = 8192
# Exceptions
#
+
class AvroRemoteException(schema.AvroException):
- """
- Raised when an error message is sent by an Avro requestor or responder.
- """
- def __init__(self, fail_msg=None):
- schema.AvroException.__init__(self, fail_msg)
+ """
+ Raised when an error message is sent by an Avro requestor or responder.
+ """
+
+ def __init__(self, fail_msg=None):
+ schema.AvroException.__init__(self, fail_msg)
+
class ConnectionClosedException(schema.AvroException):
- pass
+ pass
#
# Base IPC Classes (Requestor/Responder)
#
+
class BaseRequestor(object):
- """Base class for the client side of a protocol interaction."""
- def __init__(self, local_protocol, transceiver):
- self._local_protocol = local_protocol
- self._transceiver = transceiver
- self._remote_protocol = None
- self._remote_hash = None
- self._send_protocol = None
-
- # read-only properties
- local_protocol = property(lambda self: self._local_protocol)
- transceiver = property(lambda self: self._transceiver)
-
- # read/write properties
- def set_remote_protocol(self, new_remote_protocol):
- self._remote_protocol = new_remote_protocol
- REMOTE_PROTOCOLS[self.transceiver.remote_name] = self.remote_protocol
- remote_protocol = property(lambda self: self._remote_protocol,
- set_remote_protocol)
-
- def set_remote_hash(self, new_remote_hash):
- self._remote_hash = new_remote_hash
- REMOTE_HASHES[self.transceiver.remote_name] = self.remote_hash
- remote_hash = property(lambda self: self._remote_hash, set_remote_hash)
-
- def set_send_protocol(self, new_send_protocol):
- self._send_protocol = new_send_protocol
- send_protocol = property(lambda self: self._send_protocol, set_send_protocol)
-
- def request(self, message_name, request_datum):
- """
- Writes a request message and reads a response or error message.
- """
- # build handshake and call request
- buffer_writer = io.BytesIO()
- buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
- self.write_handshake_request(buffer_encoder)
- self.write_call_request(message_name, request_datum, buffer_encoder)
-
- # send the handshake and call request; block until call response
- call_request = buffer_writer.getvalue()
- return self.issue_request(call_request, message_name, request_datum)
-
- def write_handshake_request(self, encoder):
- local_hash = self.local_protocol.md5
- remote_name = self.transceiver.remote_name
- remote_hash = REMOTE_HASHES.get(remote_name)
- if remote_hash is None:
- remote_hash = local_hash
- self.remote_protocol = self.local_protocol
- request_datum = {}
- request_datum['clientHash'] = local_hash
- request_datum['serverHash'] = remote_hash
- if self.send_protocol:
- request_datum['clientProtocol'] = unicode(self.local_protocol)
- HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)
-
- def write_call_request(self, message_name, request_datum, encoder):
- """
- The format of a call request is:
- * request metadata, a map with values of type bytes
- * the message name, an Avro string, followed by
- * the message parameters. Parameters are serialized according to
- the message's request declaration.
- """
- # request metadata (not yet implemented)
- request_metadata = {}
- META_WRITER.write(request_metadata, encoder)
-
- # message name
- message = self.local_protocol.messages.get(message_name)
- if message is None:
- raise schema.AvroException('Unknown message: %s' % message_name)
- encoder.write_utf8(message.name)
-
- # message parameters
- self.write_request(message.request, request_datum, encoder)
-
- def write_request(self, request_schema, request_datum, encoder):
- datum_writer = avro.io.DatumWriter(request_schema)
- datum_writer.write(request_datum, encoder)
-
- def read_handshake_response(self, decoder):
- handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
- match = handshake_response.get('match')
- if match == 'BOTH':
- self.send_protocol = False
- return True
- elif match == 'CLIENT':
- if self.send_protocol:
- raise schema.AvroException('Handshake failure.')
- self.remote_protocol = protocol.parse(
- handshake_response.get('serverProtocol'))
- self.remote_hash = handshake_response.get('serverHash')
- self.send_protocol = False
- return True
- elif match == 'NONE':
- if self.send_protocol:
- raise schema.AvroException('Handshake failure.')
- self.remote_protocol = protocol.parse(
- handshake_response.get('serverProtocol'))
- self.remote_hash = handshake_response.get('serverHash')
- self.send_protocol = True
- return False
- else:
- raise schema.AvroException('Unexpected match: %s' % match)
-
- def read_call_response(self, message_name, decoder):
- """
- The format of a call response is:
- * response metadata, a map with values of type bytes
- * a one-byte error flag boolean, followed by either:
- o if the error flag is false,
- the message response, serialized per the message's response schema.
- o if the error flag is true,
- the error, serialized per the message's error union schema.
- """
- # response metadata
- response_metadata = META_READER.read(decoder)
-
- # remote response schema
- remote_message_schema = self.remote_protocol.messages.get(message_name)
- if remote_message_schema is None:
- raise schema.AvroException('Unknown remote message: %s' % message_name)
-
- # local response schema
- local_message_schema = self.local_protocol.messages.get(message_name)
- if local_message_schema is None:
- raise schema.AvroException('Unknown local message: %s' % message_name)
-
- # error flag
- if not decoder.read_boolean():
- writers_schema = remote_message_schema.response
- readers_schema = local_message_schema.response
- return self.read_response(writers_schema, readers_schema, decoder)
- else:
- writers_schema = remote_message_schema.errors
- readers_schema = local_message_schema.errors
- raise self.read_error(writers_schema, readers_schema, decoder)
-
- def read_response(self, writers_schema, readers_schema, decoder):
- datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
- result = datum_reader.read(decoder)
- return result
-
- def read_error(self, writers_schema, readers_schema, decoder):
- datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
- return AvroRemoteException(datum_reader.read(decoder))
+ """Base class for the client side of a protocol interaction."""
+
+ def __init__(self, local_protocol, transceiver):
+ self._local_protocol = local_protocol
+ self._transceiver = transceiver
+ self._remote_protocol = None
+ self._remote_hash = None
+ self._send_protocol = None
+
+ # read-only properties
+ local_protocol = property(lambda self: self._local_protocol)
+ transceiver = property(lambda self: self._transceiver)
+
+ # read/write properties
+ def set_remote_protocol(self, new_remote_protocol):
+ self._remote_protocol = new_remote_protocol
+ REMOTE_PROTOCOLS[self.transceiver.remote_name] = self.remote_protocol
+ remote_protocol = property(lambda self: self._remote_protocol,
+ set_remote_protocol)
+
+ def set_remote_hash(self, new_remote_hash):
+ self._remote_hash = new_remote_hash
+ REMOTE_HASHES[self.transceiver.remote_name] = self.remote_hash
+ remote_hash = property(lambda self: self._remote_hash, set_remote_hash)
+
+ def set_send_protocol(self, new_send_protocol):
+ self._send_protocol = new_send_protocol
+ send_protocol = property(lambda self: self._send_protocol, set_send_protocol)
+
+ def request(self, message_name, request_datum):
+ """
+ Writes a request message and reads a response or error message.
+ """
+ # build handshake and call request
+ buffer_writer = io.BytesIO()
+ buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
+ self.write_handshake_request(buffer_encoder)
+ self.write_call_request(message_name, request_datum, buffer_encoder)
+
+ # send the handshake and call request; block until call response
+ call_request = buffer_writer.getvalue()
+ return self.issue_request(call_request, message_name, request_datum)
+
+ def write_handshake_request(self, encoder):
+ local_hash = self.local_protocol.md5
+ remote_name = self.transceiver.remote_name
+ remote_hash = REMOTE_HASHES.get(remote_name)
+ if remote_hash is None:
+ remote_hash = local_hash
+ self.remote_protocol = self.local_protocol
+ request_datum = {}
+ request_datum['clientHash'] = local_hash
+ request_datum['serverHash'] = remote_hash
+ if self.send_protocol:
+ request_datum['clientProtocol'] = unicode(self.local_protocol)
+ HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)
+
+ def write_call_request(self, message_name, request_datum, encoder):
+ """
+ The format of a call request is:
+ * request metadata, a map with values of type bytes
+ * the message name, an Avro string, followed by
+ * the message parameters. Parameters are serialized according to
+ the message's request declaration.
+ """
+ # request metadata (not yet implemented)
+ request_metadata = {}
+ META_WRITER.write(request_metadata, encoder)
+
+ # message name
+ message = self.local_protocol.messages.get(message_name)
+ if message is None:
+ raise schema.AvroException('Unknown message: %s' % message_name)
+ encoder.write_utf8(message.name)
+
+ # message parameters
+ self.write_request(message.request, request_datum, encoder)
+
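A byte-level sketch of the call request layout described above, for a hypothetical message named "hello" whose request record has a single string parameter (the names and schema here are illustrative, not from the source):

    metadata = b'\x00'                      # empty request metadata map: count 0
    message_name = b'\x0a' + b'hello'       # zig-zag(5), then the UTF-8 name
    params = b'\x0a' + b'world'             # one string parameter, per the request schema
    assert metadata + message_name + params == b'\x00\x0ahello\x0aworld'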
+ def write_request(self, request_schema, request_datum, encoder):
+ datum_writer = avro.io.DatumWriter(request_schema)
+ datum_writer.write(request_datum, encoder)
+
+ def read_handshake_response(self, decoder):
+ handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
+ match = handshake_response.get('match')
+ if match == 'BOTH':
+ self.send_protocol = False
+ return True
+ elif match == 'CLIENT':
+ if self.send_protocol:
+ raise schema.AvroException('Handshake failure.')
+ self.remote_protocol = protocol.parse(
+ handshake_response.get('serverProtocol'))
+ self.remote_hash = handshake_response.get('serverHash')
+ self.send_protocol = False
+ return True
+ elif match == 'NONE':
+ if self.send_protocol:
+ raise schema.AvroException('Handshake failure.')
+ self.remote_protocol = protocol.parse(
+ handshake_response.get('serverProtocol'))
+ self.remote_hash = handshake_response.get('serverHash')
+ self.send_protocol = True
+ return False
+ else:
+ raise schema.AvroException('Unexpected match: %s' % match)
+
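The three 'match' outcomes handled above, condensed into a small sketch; handshake_outcome is a hypothetical helper, not part of avro.ipc:

    def handshake_outcome(match):
        """Return (adopt_server_protocol, resend_with_protocol, response_follows)."""
        if match == 'BOTH':
            return (False, False, True)   # hashes agreed; the call response follows
        if match == 'CLIENT':
            return (True, False, True)    # adopt serverProtocol/serverHash; response follows
        if match == 'NONE':
            return (True, True, False)    # adopt, then resend the request with clientProtocol
        raise ValueError('Unexpected match: %s' % match)

    assert handshake_outcome('NONE') == (True, True, False)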
+ def read_call_response(self, message_name, decoder):
+ """
+ The format of a call response is:
+ * response metadata, a map with values of type bytes
+ * a one-byte error flag boolean, followed by either:
+ o if the error flag is false,
+ the message response, serialized per the message's response schema.
+ o if the error flag is true,
+ the error, serialized per the message's error union schema.
+ """
+ # response metadata
+ response_metadata = META_READER.read(decoder)
+
+ # remote response schema
+ remote_message_schema = self.remote_protocol.messages.get(message_name)
+ if remote_message_schema is None:
+ raise schema.AvroException('Unknown remote message: %s' % message_name)
+
+ # local response schema
+ local_message_schema = self.local_protocol.messages.get(message_name)
+ if local_message_schema is None:
+ raise schema.AvroException('Unknown local message: %s' % message_name)
+
+ # error flag
+ if not decoder.read_boolean():
+ writers_schema = remote_message_schema.response
+ readers_schema = local_message_schema.response
+ return self.read_response(writers_schema, readers_schema, decoder)
+ else:
+ writers_schema = remote_message_schema.errors
+ readers_schema = local_message_schema.errors
+ raise self.read_error(writers_schema, readers_schema, decoder)
+
+ def read_response(self, writers_schema, readers_schema, decoder):
+ datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
+ result = datum_reader.read(decoder)
+ return result
+
+ def read_error(self, writers_schema, readers_schema, decoder):
+ datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
+ return AvroRemoteException(datum_reader.read(decoder))
+
class Requestor(BaseRequestor):
- def issue_request(self, call_request, message_name, request_datum):
- call_response = self.transceiver.transceive(call_request)
+ def issue_request(self, call_request, message_name, request_datum):
+ call_response = self.transceiver.transceive(call_request)
+
+ # process the handshake and call response
+ buffer_decoder = avro.io.BinaryDecoder(io.BytesIO(call_response))
+ call_response_exists = self.read_handshake_response(buffer_decoder)
+ if call_response_exists:
+ return self.read_call_response(message_name, buffer_decoder)
+ return self.request(message_name, request_datum)
- # process the handshake and call response
- buffer_decoder = avro.io.BinaryDecoder(io.BytesIO(call_response))
- call_response_exists = self.read_handshake_response(buffer_decoder)
- if call_response_exists:
- return self.read_call_response(message_name, buffer_decoder)
- return self.request(message_name, request_datum)
class Responder(object):
- """Base class for the server side of a protocol interaction."""
- def __init__(self, local_protocol):
- self._local_protocol = local_protocol
- self._local_hash = self.local_protocol.md5
- self._protocol_cache = {}
- self.set_protocol_cache(self.local_hash, self.local_protocol)
-
- # read-only properties
- local_protocol = property(lambda self: self._local_protocol)
- local_hash = property(lambda self: self._local_hash)
- protocol_cache = property(lambda self: self._protocol_cache)
-
- # utility functions to manipulate protocol cache
- def get_protocol_cache(self, hash):
- return self.protocol_cache.get(hash)
- def set_protocol_cache(self, hash, protocol):
- self.protocol_cache[hash] = protocol
-
- def respond(self, call_request):
- """
- Called by a server to deserialize a request, compute and serialize
- a response or error. Compare to 'handle()' in Thrift.
- """
- buffer_reader = io.BytesIO(call_request)
- buffer_decoder = avro.io.BinaryDecoder(buffer_reader)
- buffer_writer = io.BytesIO()
- buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
- error = None
- response_metadata = {}
-
- try:
- remote_protocol = self.process_handshake(buffer_decoder, buffer_encoder)
- # handshake failure
- if remote_protocol is None:
+ """Base class for the server side of a protocol interaction."""
+
+ def __init__(self, local_protocol):
+ self._local_protocol = local_protocol
+ self._local_hash = self.local_protocol.md5
+ self._protocol_cache = {}
+ self.set_protocol_cache(self.local_hash, self.local_protocol)
+
+ # read-only properties
+ local_protocol = property(lambda self: self._local_protocol)
+ local_hash = property(lambda self: self._local_hash)
+ protocol_cache = property(lambda self: self._protocol_cache)
+
+ # utility functions to manipulate protocol cache
+ def get_protocol_cache(self, hash):
+ return self.protocol_cache.get(hash)
+
+ def set_protocol_cache(self, hash, protocol):
+ self.protocol_cache[hash] = protocol
+
+ def respond(self, call_request):
+ """
+ Called by a server to deserialize a request, compute and serialize
+ a response or error. Compare to 'handle()' in Thrift.
+ """
+ buffer_reader = io.BytesIO(call_request)
+ buffer_decoder = avro.io.BinaryDecoder(buffer_reader)
+ buffer_writer = io.BytesIO()
+ buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
+ error = None
+ response_metadata = {}
+
+ try:
+ remote_protocol = self.process_handshake(buffer_decoder, buffer_encoder)
+ # handshake failure
+ if remote_protocol is None:
+ return buffer_writer.getvalue()
+
+ # read request using remote protocol
+ request_metadata = META_READER.read(buffer_decoder)
+ remote_message_name = buffer_decoder.read_utf8()
+
+ # get remote and local request schemas so we can do
+ # schema resolution (one fine day)
+ remote_message = remote_protocol.messages.get(remote_message_name)
+ if remote_message is None:
+ fail_msg = 'Unknown remote message: %s' % remote_message_name
+ raise schema.AvroException(fail_msg)
+ local_message = self.local_protocol.messages.get(remote_message_name)
+ if local_message is None:
+ fail_msg = 'Unknown local message: %s' % remote_message_name
+ raise schema.AvroException(fail_msg)
+ writers_schema = remote_message.request
+ readers_schema = local_message.request
+ request = self.read_request(writers_schema, readers_schema,
+ buffer_decoder)
+
+ # perform server logic
+ try:
+ response = self.invoke(local_message, request)
+ except AvroRemoteException as e:
+ error = e
+ except Exception as e:
+ error = AvroRemoteException(unicode(e))
+
+ # write response using local protocol
+ META_WRITER.write(response_metadata, buffer_encoder)
+ buffer_encoder.write_boolean(error is not None)
+ if error is None:
+ writers_schema = local_message.response
+ self.write_response(writers_schema, response, buffer_encoder)
+ else:
+ writers_schema = local_message.errors
+ self.write_error(writers_schema, error, buffer_encoder)
+ except schema.AvroException as e:
+ error = AvroRemoteException(unicode(e))
+ buffer_encoder = avro.io.BinaryEncoder(io.BytesIO())
+ META_WRITER.write(response_metadata, buffer_encoder)
+ buffer_encoder.write_boolean(True)
+ self.write_error(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
return buffer_writer.getvalue()
- # read request using remote protocol
- request_metadata = META_READER.read(buffer_decoder)
- remote_message_name = buffer_decoder.read_utf8()
-
- # get remote and local request schemas so we can do
- # schema resolution (one fine day)
- remote_message = remote_protocol.messages.get(remote_message_name)
- if remote_message is None:
- fail_msg = 'Unknown remote message: %s' % remote_message_name
- raise schema.AvroException(fail_msg)
- local_message = self.local_protocol.messages.get(remote_message_name)
- if local_message is None:
- fail_msg = 'Unknown local message: %s' % remote_message_name
- raise schema.AvroException(fail_msg)
- writers_schema = remote_message.request
- readers_schema = local_message.request
- request = self.read_request(writers_schema, readers_schema,
- buffer_decoder)
-
- # perform server logic
- try:
- response = self.invoke(local_message, request)
- except AvroRemoteException as e:
- error = e
- except Exception as e:
- error = AvroRemoteException(unicode(e))
-
- # write response using local protocol
- META_WRITER.write(response_metadata, buffer_encoder)
- buffer_encoder.write_boolean(error is not None)
- if error is None:
- writers_schema = local_message.response
- self.write_response(writers_schema, response, buffer_encoder)
- else:
- writers_schema = local_message.errors
- self.write_error(writers_schema, error, buffer_encoder)
- except schema.AvroException as e:
- error = AvroRemoteException(unicode(e))
- buffer_encoder = avro.io.BinaryEncoder(io.BytesIO())
- META_WRITER.write(response_metadata, buffer_encoder)
- buffer_encoder.write_boolean(True)
- self.write_error(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
- return buffer_writer.getvalue()
-
- def process_handshake(self, decoder, encoder):
- handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
- handshake_response = {}
-
- # determine the remote protocol
- client_hash = handshake_request.get('clientHash')
- client_protocol = handshake_request.get('clientProtocol')
- remote_protocol = self.get_protocol_cache(client_hash)
- if remote_protocol is None and client_protocol is not None:
- remote_protocol = protocol.parse(client_protocol)
- self.set_protocol_cache(client_hash, remote_protocol)
-
- # evaluate remote's guess of the local protocol
- server_hash = handshake_request.get('serverHash')
- if self.local_hash == server_hash:
- if remote_protocol is None:
- handshake_response['match'] = 'NONE'
- else:
- handshake_response['match'] = 'BOTH'
- else:
- if remote_protocol is None:
- handshake_response['match'] = 'NONE'
- else:
- handshake_response['match'] = 'CLIENT'
-
- if handshake_response['match'] != 'BOTH':
- handshake_response['serverProtocol'] = unicode(self.local_protocol)
- handshake_response['serverHash'] = self.local_hash
-
- HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
- return remote_protocol
-
- def invoke(self, local_message, request):
- """
- Aactual work done by server: cf. handler in thrift.
- """
- pass
-
- def read_request(self, writers_schema, readers_schema, decoder):
- datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
- return datum_reader.read(decoder)
-
- def write_response(self, writers_schema, response_datum, encoder):
- datum_writer = avro.io.DatumWriter(writers_schema)
- datum_writer.write(response_datum, encoder)
-
- def write_error(self, writers_schema, error_exception, encoder):
- datum_writer = avro.io.DatumWriter(writers_schema)
- datum_writer.write(unicode(error_exception), encoder)
+ def process_handshake(self, decoder, encoder):
+ handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
+ handshake_response = {}
+
+ # determine the remote protocol
+ client_hash = handshake_request.get('clientHash')
+ client_protocol = handshake_request.get('clientProtocol')
+ remote_protocol = self.get_protocol_cache(client_hash)
+ if remote_protocol is None and client_protocol is not None:
+ remote_protocol = protocol.parse(client_protocol)
+ self.set_protocol_cache(client_hash, remote_protocol)
+
+ # evaluate remote's guess of the local protocol
+ server_hash = handshake_request.get('serverHash')
+ if self.local_hash == server_hash:
+ if remote_protocol is None:
+ handshake_response['match'] = 'NONE'
+ else:
+ handshake_response['match'] = 'BOTH'
+ else:
+ if remote_protocol is None:
+ handshake_response['match'] = 'NONE'
+ else:
+ handshake_response['match'] = 'CLIENT'
+
+ if handshake_response['match'] != 'BOTH':
+ handshake_response['serverProtocol'] = unicode(self.local_protocol)
+ handshake_response['serverHash'] = self.local_hash
+
+ HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
+ return remote_protocol
+
+ def invoke(self, local_message, request):
+ """
+ Actual work done by server: cf. handler in Thrift.
+ """
+ pass
+
+ def read_request(self, writers_schema, readers_schema, decoder):
+ datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
+ return datum_reader.read(decoder)
+
+ def write_response(self, writers_schema, response_datum, encoder):
+ datum_writer = avro.io.DatumWriter(writers_schema)
+ datum_writer.write(response_datum, encoder)
+
+ def write_error(self, writers_schema, error_exception, encoder):
+ datum_writer = avro.io.DatumWriter(writers_schema)
+ datum_writer.write(unicode(error_exception), encoder)
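
For readers following the server side of this patch: the only hook a concrete server must supply is invoke(); everything else in respond() above is generic. The sketch below is illustrative only, assuming a hypothetical single-message "Echo" protocol (the JSON, names, and message are invented for the example), and simply echoes the request back.

    import avro.ipc
    import avro.protocol

    # Hypothetical one-message protocol, defined inline for the example.
    ECHO_PROTOCOL = avro.protocol.parse("""
    {
      "protocol": "Echo",
      "namespace": "example",
      "messages": {
        "echo": {
          "request": [{"name": "text", "type": "string"}],
          "response": "string"
        }
      }
    }
    """)

    class EchoResponder(avro.ipc.Responder):
        def __init__(self):
            super(EchoResponder, self).__init__(ECHO_PROTOCOL)

        def invoke(self, local_message, request):
            # Dispatch on the message name; respond() handles all (de)serialization.
            if local_message.name == 'echo':
                return request['text']
            raise avro.ipc.AvroRemoteException('Unknown message: %s' % local_message.name)

Calling EchoResponder().respond(call_request) then drives the handshake, dispatch, and error-handling path shown above.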
#
# Utility classes
#
+
class FramedReader(object):
- """Wrapper around a file-like object to read framed data."""
- def __init__(self, reader):
- self._reader = reader
-
- # read-only properties
- reader = property(lambda self: self._reader)
-
- def read_framed_message(self):
- message = []
- while True:
- buffer = io.BytesIO()
- buffer_length = self._read_buffer_length()
- if buffer_length == 0:
- return b''.join(message)
- while buffer.tell() < buffer_length:
- chunk = self.reader.read(buffer_length - buffer.tell())
- if chunk == '':
- raise ConnectionClosedException("Reader read 0 bytes.")
- buffer.write(chunk)
- message.append(buffer.getvalue())
-
- def _read_buffer_length(self):
- read = self.reader.read(BUFFER_HEADER_LENGTH)
- if read == '':
- raise ConnectionClosedException("Reader read 0 bytes.")
- return BIG_ENDIAN_INT_STRUCT.unpack(read)[0]
+ """Wrapper around a file-like object to read framed data."""
+
+ def __init__(self, reader):
+ self._reader = reader
+
+ # read-only properties
+ reader = property(lambda self: self._reader)
+
+ def read_framed_message(self):
+ message = []
+ while True:
+ buffer = io.BytesIO()
+ buffer_length = self._read_buffer_length()
+ if buffer_length == 0:
+ return b''.join(message)
+ while buffer.tell() < buffer_length:
+ chunk = self.reader.read(buffer_length - buffer.tell())
+ if chunk == '':
+ raise ConnectionClosedException("Reader read 0 bytes.")
+ buffer.write(chunk)
+ message.append(buffer.getvalue())
+
+ def _read_buffer_length(self):
+ read = self.reader.read(BUFFER_HEADER_LENGTH)
+ if read == '':
+ raise ConnectionClosedException("Reader read 0 bytes.")
+ return BIG_ENDIAN_INT_STRUCT.unpack(read)[0]
+
class FramedWriter(object):
- """Wrapper around a file-like object to write framed data."""
- def __init__(self, writer):
- self._writer = writer
-
- # read-only properties
- writer = property(lambda self: self._writer)
-
- def write_framed_message(self, message):
- message_length = len(message)
- total_bytes_sent = 0
- while message_length - total_bytes_sent > 0:
- if message_length - total_bytes_sent > BUFFER_SIZE:
- buffer_length = BUFFER_SIZE
- else:
- buffer_length = message_length - total_bytes_sent
- self.write_buffer(message[total_bytes_sent:
- (total_bytes_sent + buffer_length)])
- total_bytes_sent += buffer_length
- # A message is always terminated by a zero-length buffer.
- self.write_buffer_length(0)
-
- def write_buffer(self, chunk):
- buffer_length = len(chunk)
- self.write_buffer_length(buffer_length)
- self.writer.write(chunk)
-
- def write_buffer_length(self, n):
- self.writer.write(BIG_ENDIAN_INT_STRUCT.pack(n))
+ """Wrapper around a file-like object to write framed data."""
+
+ def __init__(self, writer):
+ self._writer = writer
+
+ # read-only properties
+ writer = property(lambda self: self._writer)
+
+ def write_framed_message(self, message):
+ message_length = len(message)
+ total_bytes_sent = 0
+ while message_length - total_bytes_sent > 0:
+ if message_length - total_bytes_sent > BUFFER_SIZE:
+ buffer_length = BUFFER_SIZE
+ else:
+ buffer_length = message_length - total_bytes_sent
+ self.write_buffer(message[total_bytes_sent:
+ (total_bytes_sent + buffer_length)])
+ total_bytes_sent += buffer_length
+ # A message is always terminated by a zero-length buffer.
+ self.write_buffer_length(0)
+
+ def write_buffer(self, chunk):
+ buffer_length = len(chunk)
+ self.write_buffer_length(buffer_length)
+ self.writer.write(chunk)
+
+ def write_buffer_length(self, n):
+ self.writer.write(BIG_ENDIAN_INT_STRUCT.pack(n))
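
As an aside on the wire format these two wrappers implement: each buffer is a four-byte big-endian length prefix followed by that many payload bytes, and a message is terminated by a zero-length buffer. A minimal round trip through an in-memory stream, using only the classes above:

    import io
    import avro.ipc

    buf = io.BytesIO()
    avro.ipc.FramedWriter(buf).write_framed_message(b'hello avro')

    # 4-byte length prefix, 10 payload bytes, then the zero-length terminator.
    assert buf.getvalue() == b'\x00\x00\x00\x0ahello avro\x00\x00\x00\x00'

    buf.seek(0)
    assert avro.ipc.FramedReader(buf).read_framed_message() == b'hello avro'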
#
# Transceiver Implementations
#
+
class HTTPTransceiver(object):
- """
- A simple HTTP-based transceiver implementation.
- Useful for clients but not for servers
- """
- def __init__(self, host, port, req_resource='/'):
- self.req_resource = req_resource
- self.conn = httplib.HTTPConnection(host, port)
- self.conn.connect()
- self.remote_name = self.conn.sock.getsockname()
-
- def transceive(self, request):
- self.write_framed_message(request)
- result = self.read_framed_message()
- return result
-
- def read_framed_message(self):
- response = self.conn.getresponse()
- response_reader = FramedReader(response)
- framed_message = response_reader.read_framed_message()
- response.read() # ensure we're ready for subsequent requests
- return framed_message
-
- def write_framed_message(self, message):
- req_method = 'POST'
- req_headers = {'Content-Type': 'avro/binary'}
-
- req_body_buffer = FramedWriter(io.BytesIO())
- req_body_buffer.write_framed_message(message)
- req_body = req_body_buffer.writer.getvalue()
-
- self.conn.request(req_method, self.req_resource, req_body, req_headers)
-
- def close(self):
- self.conn.close()
+ """
+ A simple HTTP-based transceiver implementation.
+ Useful for clients but not for servers
+ """
+
+ def __init__(self, host, port, req_resource='/'):
+ self.req_resource = req_resource
+ self.conn = httplib.HTTPConnection(host, port)
+ self.conn.connect()
+ self.remote_name = self.conn.sock.getsockname()
+
+ def transceive(self, request):
+ self.write_framed_message(request)
+ result = self.read_framed_message()
+ return result
+
+ def read_framed_message(self):
+ response = self.conn.getresponse()
+ response_reader = FramedReader(response)
+ framed_message = response_reader.read_framed_message()
+ response.read() # ensure we're ready for subsequent requests
+ return framed_message
+
+ def write_framed_message(self, message):
+ req_method = 'POST'
+ req_headers = {'Content-Type': 'avro/binary'}
+
+ req_body_buffer = FramedWriter(io.BytesIO())
+ req_body_buffer.write_framed_message(message)
+ req_body = req_body_buffer.writer.getvalue()
+
+ self.conn.request(req_method, self.req_resource, req_body, req_headers)
+
+ def close(self):
+ self.conn.close()
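
Tying the client pieces together, the usual pattern is to hand an HTTPTransceiver to a Requestor (defined earlier in ipc.py) and issue calls by message name. The host, port, and protocol file below are placeholders, so treat this as a sketch of the call pattern rather than a runnable endpoint.

    import avro.ipc
    import avro.protocol

    # Placeholder protocol definition and endpoint.
    with open('echo.avpr') as f:
        ECHO_PROTOCOL = avro.protocol.parse(f.read())

    client = avro.ipc.HTTPTransceiver('localhost', 8080)
    try:
        requestor = avro.ipc.Requestor(ECHO_PROTOCOL, client)
        print(requestor.request('echo', {'text': 'ping'}))
    finally:
        client.close()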
#
# Server Implementations (none yet)
diff --git a/lang/py/avro/protocol.py b/lang/py/avro/protocol.py
index e5dc22e..6bcd9b8 100644
--- a/lang/py/avro/protocol.py
+++ b/lang/py/avro/protocol.py
@@ -27,14 +27,14 @@ import json
import avro.schema
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
try:
- basestring # type: ignore
+ basestring # type: ignore
except NameError:
- basestring = (bytes, unicode)
+ basestring = (bytes, unicode)
#
# Constants
@@ -47,207 +47,216 @@ VALID_TYPE_SCHEMA_TYPES = ('enum', 'record', 'error', 'fixed')
# Exceptions
#
+
class ProtocolParseException(avro.schema.AvroException):
- pass
+ pass
#
# Base Classes
#
+
class Protocol(object):
- """An application protocol."""
- def _parse_types(self, types, type_names):
- type_objects = []
- for type in types:
- type_object = avro.schema.make_avsc_object(type, type_names)
- if type_object.type not in VALID_TYPE_SCHEMA_TYPES:
- fail_msg = 'Type %s not an enum, fixed, record, or error.' % type
- raise ProtocolParseException(fail_msg)
- type_objects.append(type_object)
- return type_objects
-
- def _parse_messages(self, messages, names):
- message_objects = {}
- for name, body in messages.items():
- if name in message_objects:
- fail_msg = 'Message name "%s" repeated.' % name
- raise ProtocolParseException(fail_msg)
- try:
- request = body.get('request')
- response = body.get('response')
- errors = body.get('errors')
- except AttributeError:
- fail_msg = 'Message name "%s" has non-object body %s.' % (name, body)
- raise ProtocolParseException(fail_msg)
- message_objects[name] = Message(name, request, response, errors, names)
- return message_objects
-
- def __init__(self, name, namespace=None, types=None, messages=None):
- # Ensure valid ctor args
- if not name:
- fail_msg = 'Protocols must have a non-empty name.'
- raise ProtocolParseException(fail_msg)
- elif not isinstance(name, basestring):
- fail_msg = 'The name property must be a string.'
- raise ProtocolParseException(fail_msg)
- elif not (namespace is None or isinstance(namespace, basestring)):
- fail_msg = 'The namespace property must be a string.'
- raise ProtocolParseException(fail_msg)
- elif not (types is None or isinstance(types, list)):
- fail_msg = 'The types property must be a list.'
- raise ProtocolParseException(fail_msg)
- elif not (messages is None or callable(getattr(messages, 'get', None))):
- fail_msg = 'The messages property must be a JSON object.'
- raise ProtocolParseException(fail_msg)
-
- self._props = {}
- self.set_prop('name', name)
- type_names = avro.schema.Names()
- if namespace is not None:
- self.set_prop('namespace', namespace)
- type_names.default_namespace = namespace
- if types is not None:
- self.set_prop('types', self._parse_types(types, type_names))
- if messages is not None:
- self.set_prop('messages', self._parse_messages(messages, type_names))
- self._md5 = hashlib.md5(str(self).encode()).digest()
-
- # read-only properties
- @property
- def name(self):
- return self.get_prop('name')
-
- @property
- def namespace(self):
- return self.get_prop('namespace')
-
- @property
- def fullname(self):
- return avro.schema.Name(self.name, self.namespace, None).fullname
-
- @property
- def types(self):
- return self.get_prop('types')
-
- @property
- def types_dict(self):
- return {type.name: type for type in self.types}
-
- @property
- def messages(self):
- return self.get_prop('messages')
-
- @property
- def md5(self):
- return self._md5
-
- @property
- def props(self):
- return self._props
-
- # utility functions to manipulate properties dict
- def get_prop(self, key):
- return self.props.get(key)
- def set_prop(self, key, value):
- self.props[key] = value
-
- def to_json(self):
- to_dump = {}
- to_dump['protocol'] = self.name
- names = avro.schema.Names(default_namespace=self.namespace)
- if self.namespace:
- to_dump['namespace'] = self.namespace
- if self.types:
- to_dump['types'] = [ t.to_json(names) for t in self.types ]
- if self.messages:
- messages_dict = {}
- for name, body in self.messages.items():
- messages_dict[name] = body.to_json(names)
- to_dump['messages'] = messages_dict
- return to_dump
-
- def __str__(self):
- return json.dumps(self.to_json())
-
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
+ """An application protocol."""
+
+ def _parse_types(self, types, type_names):
+ type_objects = []
+ for type in types:
+ type_object = avro.schema.make_avsc_object(type, type_names)
+ if type_object.type not in VALID_TYPE_SCHEMA_TYPES:
+ fail_msg = 'Type %s not an enum, fixed, record, or error.' % type
+ raise ProtocolParseException(fail_msg)
+ type_objects.append(type_object)
+ return type_objects
+
+ def _parse_messages(self, messages, names):
+ message_objects = {}
+ for name, body in messages.items():
+ if name in message_objects:
+ fail_msg = 'Message name "%s" repeated.' % name
+ raise ProtocolParseException(fail_msg)
+ try:
+ request = body.get('request')
+ response = body.get('response')
+ errors = body.get('errors')
+ except AttributeError:
+ fail_msg = 'Message name "%s" has non-object body %s.' % (name, body)
+ raise ProtocolParseException(fail_msg)
+ message_objects[name] = Message(name, request, response, errors, names)
+ return message_objects
+
+ def __init__(self, name, namespace=None, types=None, messages=None):
+ # Ensure valid ctor args
+ if not name:
+ fail_msg = 'Protocols must have a non-empty name.'
+ raise ProtocolParseException(fail_msg)
+ elif not isinstance(name, basestring):
+ fail_msg = 'The name property must be a string.'
+ raise ProtocolParseException(fail_msg)
+ elif not (namespace is None or isinstance(namespace, basestring)):
+ fail_msg = 'The namespace property must be a string.'
+ raise ProtocolParseException(fail_msg)
+ elif not (types is None or isinstance(types, list)):
+ fail_msg = 'The types property must be a list.'
+ raise ProtocolParseException(fail_msg)
+ elif not (messages is None or callable(getattr(messages, 'get', None))):
+ fail_msg = 'The messages property must be a JSON object.'
+ raise ProtocolParseException(fail_msg)
+
+ self._props = {}
+ self.set_prop('name', name)
+ type_names = avro.schema.Names()
+ if namespace is not None:
+ self.set_prop('namespace', namespace)
+ type_names.default_namespace = namespace
+ if types is not None:
+ self.set_prop('types', self._parse_types(types, type_names))
+ if messages is not None:
+ self.set_prop('messages', self._parse_messages(messages, type_names))
+ self._md5 = hashlib.md5(str(self).encode()).digest()
+
+ # read-only properties
+ @property
+ def name(self):
+ return self.get_prop('name')
+
+ @property
+ def namespace(self):
+ return self.get_prop('namespace')
+
+ @property
+ def fullname(self):
+ return avro.schema.Name(self.name, self.namespace, None).fullname
+
+ @property
+ def types(self):
+ return self.get_prop('types')
+
+ @property
+ def types_dict(self):
+ return {type.name: type for type in self.types}
+
+ @property
+ def messages(self):
+ return self.get_prop('messages')
+
+ @property
+ def md5(self):
+ return self._md5
+
+ @property
+ def props(self):
+ return self._props
+
+ # utility functions to manipulate properties dict
+ def get_prop(self, key):
+ return self.props.get(key)
+
+ def set_prop(self, key, value):
+ self.props[key] = value
+
+ def to_json(self):
+ to_dump = {}
+ to_dump['protocol'] = self.name
+ names = avro.schema.Names(default_namespace=self.namespace)
+ if self.namespace:
+ to_dump['namespace'] = self.namespace
+ if self.types:
+ to_dump['types'] = [t.to_json(names) for t in self.types]
+ if self.messages:
+ messages_dict = {}
+ for name, body in self.messages.items():
+ messages_dict[name] = body.to_json(names)
+ to_dump['messages'] = messages_dict
+ return to_dump
+
+ def __str__(self):
+ return json.dumps(self.to_json())
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
+
class Message(object):
- """A Protocol message."""
- def _parse_request(self, request, names):
- if not isinstance(request, list):
- fail_msg = 'Request property not a list: %s' % request
- raise ProtocolParseException(fail_msg)
- return avro.schema.RecordSchema(None, None, request, names, 'request')
-
- def _parse_response(self, response, names):
- if isinstance(response, basestring) and names.has_name(response, None):
- return names.get_name(response, None)
- else:
- return avro.schema.make_avsc_object(response, names)
-
- def _parse_errors(self, errors, names):
- if not isinstance(errors, list):
- fail_msg = 'Errors property not a list: %s' % errors
- raise ProtocolParseException(fail_msg)
- errors_for_parsing = {'type': 'error_union', 'declared_errors': errors}
- return avro.schema.make_avsc_object(errors_for_parsing, names)
-
- def __init__(self, name, request, response, errors=None, names=None):
- self._name = name
-
- self._props = {}
- self.set_prop('request', self._parse_request(request, names))
- self.set_prop('response', self._parse_response(response, names))
- self.set_prop('errors', self._parse_errors(errors or [], names))
-
- # read-only properties
- name = property(lambda self: self._name)
- request = property(lambda self: self.get_prop('request'))
- response = property(lambda self: self.get_prop('response'))
- errors = property(lambda self: self.get_prop('errors'))
- props = property(lambda self: self._props)
-
- # utility functions to manipulate properties dict
- def get_prop(self, key):
- return self.props.get(key)
- def set_prop(self, key, value):
- self.props[key] = value
-
- def __str__(self):
- return json.dumps(self.to_json())
-
- def to_json(self, names=None):
- if names is None:
- names = avro.schema.Names()
- to_dump = {}
- to_dump['request'] = self.request.to_json(names)
- to_dump['response'] = self.response.to_json(names)
- if self.errors:
- to_dump['errors'] = self.errors.to_json(names)
- return to_dump
-
- def __eq__(self, that):
- return self.name == that.name and self.props == that.props
+ """A Protocol message."""
+
+ def _parse_request(self, request, names):
+ if not isinstance(request, list):
+ fail_msg = 'Request property not a list: %s' % request
+ raise ProtocolParseException(fail_msg)
+ return avro.schema.RecordSchema(None, None, request, names, 'request')
+
+ def _parse_response(self, response, names):
+ if isinstance(response, basestring) and names.has_name(response, None):
+ return names.get_name(response, None)
+ else:
+ return avro.schema.make_avsc_object(response, names)
+
+ def _parse_errors(self, errors, names):
+ if not isinstance(errors, list):
+ fail_msg = 'Errors property not a list: %s' % errors
+ raise ProtocolParseException(fail_msg)
+ errors_for_parsing = {'type': 'error_union', 'declared_errors': errors}
+ return avro.schema.make_avsc_object(errors_for_parsing, names)
+
+ def __init__(self, name, request, response, errors=None, names=None):
+ self._name = name
+
+ self._props = {}
+ self.set_prop('request', self._parse_request(request, names))
+ self.set_prop('response', self._parse_response(response, names))
+ self.set_prop('errors', self._parse_errors(errors or [], names))
+
+ # read-only properties
+ name = property(lambda self: self._name)
+ request = property(lambda self: self.get_prop('request'))
+ response = property(lambda self: self.get_prop('response'))
+ errors = property(lambda self: self.get_prop('errors'))
+ props = property(lambda self: self._props)
+
+ # utility functions to manipulate properties dict
+ def get_prop(self, key):
+ return self.props.get(key)
+
+ def set_prop(self, key, value):
+ self.props[key] = value
+
+ def __str__(self):
+ return json.dumps(self.to_json())
+
+ def to_json(self, names=None):
+ if names is None:
+ names = avro.schema.Names()
+ to_dump = {}
+ to_dump['request'] = self.request.to_json(names)
+ to_dump['response'] = self.response.to_json(names)
+ if self.errors:
+ to_dump['errors'] = self.errors.to_json(names)
+ return to_dump
+
+ def __eq__(self, that):
+ return self.name == that.name and self.props == that.props
+
def make_avpr_object(json_data):
- """Build Avro Protocol from data parsed out of JSON string."""
- try:
- name = json_data.get('protocol')
- namespace = json_data.get('namespace')
- types = json_data.get('types')
- messages = json_data.get('messages')
- except AttributeError:
- raise ProtocolParseException('Not a JSON object: %s' % json_data)
- return Protocol(name, namespace, types, messages)
+ """Build Avro Protocol from data parsed out of JSON string."""
+ try:
+ name = json_data.get('protocol')
+ namespace = json_data.get('namespace')
+ types = json_data.get('types')
+ messages = json_data.get('messages')
+ except AttributeError:
+ raise ProtocolParseException('Not a JSON object: %s' % json_data)
+ return Protocol(name, namespace, types, messages)
+
def parse(json_string):
- """Constructs the Protocol from the JSON text."""
- try:
- json_data = json.loads(json_string)
- except ValueError:
- raise ProtocolParseException('Error parsing JSON: %s' % json_string)
-
- # construct the Avro Protocol object
- return make_avpr_object(json_data)
+ """Constructs the Protocol from the JSON text."""
+ try:
+ json_data = json.loads(json_string)
+ except ValueError:
+ raise ProtocolParseException('Error parsing JSON: %s' % json_string)
+
+ # construct the Avro Protocol object
+ return make_avpr_object(json_data)
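
To see the entry points above in action: parse() takes the protocol JSON text and returns a Protocol whose read-only properties (fullname, messages, md5) come from the class defined earlier in this file. The protocol below is invented for the example.

    import avro.protocol

    proto = avro.protocol.parse("""
    {
      "protocol": "HelloWorld",
      "namespace": "example.proto",
      "types": [
        {"type": "record", "name": "Greeting",
         "fields": [{"name": "message", "type": "string"}]}
      ],
      "messages": {
        "hello": {"request": [{"name": "greeting", "type": "Greeting"}],
                  "response": "Greeting"}
      }
    }
    """)

    assert proto.fullname == 'example.proto.HelloWorld'
    assert 'hello' in proto.messages
    assert len(proto.md5) == 16  # raw 16-byte digest of the canonical JSON form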
diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py
index 1248ec5..eaada36 100644
--- a/lang/py/avro/schema.py
+++ b/lang/py/avro/schema.py
@@ -49,14 +49,14 @@ import warnings
from avro import constants
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
try:
- basestring # type: ignore
+ basestring # type: ignore
except NameError:
- basestring = (bytes, unicode)
+ basestring = (bytes, unicode)
#
# Constants
@@ -68,824 +68,860 @@ except NameError:
_BASE_NAME_PATTERN = re.compile(r'(?:^|\.)[A-Za-z_][A-Za-z0-9_]*$')
PRIMITIVE_TYPES = (
- 'null',
- 'boolean',
- 'string',
- 'bytes',
- 'int',
- 'long',
- 'float',
- 'double',
+ 'null',
+ 'boolean',
+ 'string',
+ 'bytes',
+ 'int',
+ 'long',
+ 'float',
+ 'double',
)
NAMED_TYPES = (
- 'fixed',
- 'enum',
- 'record',
- 'error',
+ 'fixed',
+ 'enum',
+ 'record',
+ 'error',
)
VALID_TYPES = PRIMITIVE_TYPES + NAMED_TYPES + (
- 'array',
- 'map',
- 'union',
- 'request',
- 'error_union'
+ 'array',
+ 'map',
+ 'union',
+ 'request',
+ 'error_union'
)
SCHEMA_RESERVED_PROPS = (
- 'type',
- 'name',
- 'namespace',
- 'fields', # Record
- 'items', # Array
- 'size', # Fixed
- 'symbols', # Enum
- 'values', # Map
- 'doc',
+ 'type',
+ 'name',
+ 'namespace',
+ 'fields', # Record
+ 'items', # Array
+ 'size', # Fixed
+ 'symbols', # Enum
+ 'values', # Map
+ 'doc',
)
FIELD_RESERVED_PROPS = (
- 'default',
- 'name',
- 'doc',
- 'order',
- 'type',
+ 'default',
+ 'name',
+ 'doc',
+ 'order',
+ 'type',
)
VALID_FIELD_SORT_ORDERS = (
- 'ascending',
- 'descending',
- 'ignore',
+ 'ascending',
+ 'descending',
+ 'ignore',
)
#
# Exceptions
#
+
class AvroException(Exception):
- pass
+ pass
+
class SchemaParseException(AvroException):
- pass
+ pass
+
class InvalidName(SchemaParseException):
- """User attempted to parse a schema with an invalid name."""
+ """User attempted to parse a schema with an invalid name."""
+
class AvroWarning(UserWarning):
- """Base class for warnings."""
+ """Base class for warnings."""
+
class IgnoredLogicalType(AvroWarning):
- """Warnings for unknown or invalid logical types."""
+ """Warnings for unknown or invalid logical types."""
def validate_basename(basename):
- """Raise InvalidName if the given basename is not a valid name."""
- if not _BASE_NAME_PATTERN.search(basename):
- raise InvalidName("{!s} is not a valid Avro name because it "
- "does not match the pattern {!s}".format(
- basename, _BASE_NAME_PATTERN.pattern))
+ """Raise InvalidName if the given basename is not a valid name."""
+ if not _BASE_NAME_PATTERN.search(basename):
+ raise InvalidName("{!s} is not a valid Avro name because it "
+ "does not match the pattern {!s}".format(
+ basename, _BASE_NAME_PATTERN.pattern))
#
# Base Classes
#
-class Schema(object):
- """Base class for all Schema classes."""
- _props = None
- def __init__(self, type, other_props=None):
- # Ensure valid ctor args
- if not isinstance(type, basestring):
- fail_msg = 'Schema type must be a string.'
- raise SchemaParseException(fail_msg)
- elif type not in VALID_TYPES:
- fail_msg = '%s is not a valid type.' % type
- raise SchemaParseException(fail_msg)
-
- # add members
- if self._props is None:
- self._props = {}
- self.set_prop('type', type)
- self.type = type
- self._props.update(other_props or {})
-
- # Read-only properties dict. Printing schemas
- # creates JSON properties directly from this dict.
- props = property(lambda self: self._props)
-
- # Read-only property dict. Non-reserved properties
- other_props = property(lambda self: get_other_props(self._props, SCHEMA_RESERVED_PROPS),
- doc="dictionary of non-reserved properties")
-
- def check_props(self, other, props):
- """Check that the given props are identical in two schemas.
-
- @arg other: The other schema to check
- @arg props: An iterable of properties to check
- @return bool: True if all the properties match
- """
- return all(getattr(self, prop) == getattr(other, prop) for prop in props)
-
- def match(self, writer):
- """Return True if the current schema (as reader) matches the writer schema.
-
- @arg writer: the writer schema to match against.
- @return bool
- """
- raise NotImplemented("Must be implemented by subclasses")
-
- # utility functions to manipulate properties dict
- def get_prop(self, key):
- return self._props.get(key)
-
- def set_prop(self, key, value):
- self._props[key] = value
-
- def __str__(self):
- return json.dumps(self.to_json())
-
- def to_json(self, names):
- """
- Converts the schema object into its AVRO specification representation.
-
- Schema types that have names (records, enums, and fixed) must
- be aware of not re-defining schemas that are already listed
- in the parameter names.
- """
- raise Exception("Must be implemented by subclasses.")
+class Schema(object):
+ """Base class for all Schema classes."""
+ _props = None
+
+ def __init__(self, type, other_props=None):
+ # Ensure valid ctor args
+ if not isinstance(type, basestring):
+ fail_msg = 'Schema type must be a string.'
+ raise SchemaParseException(fail_msg)
+ elif type not in VALID_TYPES:
+ fail_msg = '%s is not a valid type.' % type
+ raise SchemaParseException(fail_msg)
+
+ # add members
+ if self._props is None:
+ self._props = {}
+ self.set_prop('type', type)
+ self.type = type
+ self._props.update(other_props or {})
+
+ # Read-only properties dict. Printing schemas
+ # creates JSON properties directly from this dict.
+ props = property(lambda self: self._props)
+
+ # Read-only property dict. Non-reserved properties
+ other_props = property(lambda self: get_other_props(self._props, SCHEMA_RESERVED_PROPS),
+ doc="dictionary of non-reserved properties")
+
+ def check_props(self, other, props):
+ """Check that the given props are identical in two schemas.
+
+ @arg other: The other schema to check
+ @arg props: An iterable of properties to check
+ @return bool: True if all the properties match
+ """
+ return all(getattr(self, prop) == getattr(other, prop) for prop in props)
+
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the writer schema to match against.
+ @return bool
+ """
+ raise NotImplementedError("Must be implemented by subclasses")
+
+ # utility functions to manipulate properties dict
+ def get_prop(self, key):
+ return self._props.get(key)
+
+ def set_prop(self, key, value):
+ self._props[key] = value
+
+ def __str__(self):
+ return json.dumps(self.to_json())
+
+ def to_json(self, names):
+ """
+ Converts the schema object into its AVRO specification representation.
+
+ Schema types that have names (records, enums, and fixed) must
+ be aware of not re-defining schemas that are already listed
+ in the parameter names.
+ """
+ raise Exception("Must be implemented by subclasses.")
class Name(object):
- """Class to describe Avro name."""
-
- _full = None
-
- def __init__(self, name_attr, space_attr, default_space):
- """The fullname is determined in one of the following ways:
-
- - A name and namespace are both specified. For example, one might use "name": "X", "namespace": "org.foo" to indicate the fullname org.foo.X.
- - A fullname is specified. If the name specified contains a dot, then it is assumed to be a fullname, and any namespace also specified is ignored. For example, use "name": "org.foo.X" to indicate the fullname org.foo.X.
- - A name only is specified, i.e., a name that contains no dots. In this case the namespace is taken from the most tightly enclosing schema or protocol. For example, if "name": "X" is specified, and this occurs within a field of the record definition of org.foo.Y, then the fullname is org.foo.X. If there is no enclosing namespace then the null namespace is used.
-
- References to previously defined names are as in the latter two cases above: if they contain a dot they are a fullname, if they do not contain a dot, the namespace is the namespace of the enclosing definition.
-
- @arg name_attr: name value read in schema or None.
- @arg space_attr: namespace value read in schema or None. The empty string may be used as a namespace to indicate the null namespace.
- @arg default_space: the current default space or None.
- """
- if name_attr is None:
- return
- if name_attr == "":
- raise SchemaParseException('Name must not be the empty string.')
-
- if '.' in name_attr or space_attr == "" or not (space_attr or default_space):
- # The empty string may be used as a namespace to indicate the null namespace.
- self._full = name_attr
- else:
- self._full = "{!s}.{!s}".format(space_attr or default_space, name_attr)
-
- self._validate_fullname(self._full)
-
- def _validate_fullname(self, fullname):
- for name in fullname.split('.'):
- validate_basename(name)
-
- def __eq__(self, other):
- """Equality of names is defined on the fullname and is case-sensitive."""
- return isinstance(other, Name) and self.fullname == other.fullname
-
- @property
- def fullname(self):
- return self._full
-
- @property
- def space(self):
- """Back out a namespace from full name."""
- if self._full is None:
- return None
- return self._full.rsplit(".", 1)[0] if "." in self._full else None
-
- def get_space(self):
- warnings.warn('Name.get_space() is deprecated in favor of Name.space')
- return self.space
+ """Class to describe Avro name."""
+
+ _full = None
+
+ def __init__(self, name_attr, space_attr, default_space):
+ """The fullname is determined in one of the following ways:
+
+ - A name and namespace are both specified. For example, one might use "name": "X", "namespace": "org.foo" to indicate the fullname org.foo.X.
+ - A fullname is specified. If the name specified contains a dot,
+ then it is assumed to be a fullname, and any namespace also specified is ignored.
+ For example, use "name": "org.foo.X" to indicate the fullname org.foo.X.
+ - A name only is specified, i.e., a name that contains no dots.
+ In this case the namespace is taken from the most tightly enclosing schema or protocol.
+ For example, if "name": "X" is specified, and this occurs within a field of
+ the record definition of org.foo.Y, then the fullname is org.foo.X.
+ If there is no enclosing namespace then the null namespace is used.
+
+ References to previously defined names are as in the latter two cases above:
+ if they contain a dot they are a fullname,
+ if they do not contain a dot, the namespace is the namespace of the enclosing definition.
+
+ @arg name_attr: name value read in schema or None.
+ @arg space_attr: namespace value read in schema or None. The empty string may be used as a namespace to indicate the null namespace.
+ @arg default_space: the current default space or None.
+ """
+ if name_attr is None:
+ return
+ if name_attr == "":
+ raise SchemaParseException('Name must not be the empty string.')
+
+ if '.' in name_attr or space_attr == "" or not (space_attr or default_space):
+ # The empty string may be used as a namespace to indicate the null namespace.
+ self._full = name_attr
+ else:
+ self._full = "{!s}.{!s}".format(space_attr or default_space, name_attr)
+
+ self._validate_fullname(self._full)
+
+ def _validate_fullname(self, fullname):
+ for name in fullname.split('.'):
+ validate_basename(name)
+
+ def __eq__(self, other):
+ """Equality of names is defined on the fullname and is case-sensitive."""
+ return isinstance(other, Name) and self.fullname == other.fullname
+
+ @property
+ def fullname(self):
+ return self._full
+
+ @property
+ def space(self):
+ """Back out a namespace from full name."""
+ if self._full is None:
+ return None
+ return self._full.rsplit(".", 1)[0] if "." in self._full else None
+
+ def get_space(self):
+ warnings.warn('Name.get_space() is deprecated in favor of Name.space')
+ return self.space
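
The resolution rules spelled out in the docstring above are easy to check directly; this small sketch covers the three cases (explicit namespace, dotted fullname, fallback to the enclosing default namespace) plus the null namespace.

    from avro.schema import Name

    # Name plus namespace.
    assert Name('X', 'org.foo', None).fullname == 'org.foo.X'
    # A dotted name is already a fullname; any namespace argument is ignored.
    assert Name('org.foo.X', 'org.bar', None).fullname == 'org.foo.X'
    # A bare name inherits the enclosing default namespace.
    assert Name('X', None, 'org.foo').fullname == 'org.foo.X'
    # With no namespace anywhere, the null namespace is used.
    assert Name('X', None, None).fullname == 'X'
    assert Name('X', None, None).space is None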
class Names(object):
- """Track name set and default namespace during parsing."""
- def __init__(self, default_namespace=None):
- self.names = {}
- self.default_namespace = default_namespace
-
- def has_name(self, name_attr, space_attr):
- test = Name(name_attr, space_attr, self.default_namespace).fullname
- return test in self.names
-
- def get_name(self, name_attr, space_attr):
- test = Name(name_attr, space_attr, self.default_namespace).fullname
- if test not in self.names:
- return None
- return self.names[test]
-
- def prune_namespace(self, properties):
- """given a properties, return properties with namespace removed if
- it matches the own default namespace"""
- if self.default_namespace is None:
- # I have no default -- no change
- return properties
- if 'namespace' not in properties:
- # he has no namespace - no change
- return properties
- if properties['namespace'] != self.default_namespace:
- # we're different - leave his stuff alone
- return properties
- # we each have a namespace and it's redundant. delete his.
- prunable = properties.copy()
- del(prunable['namespace'])
- return prunable
-
- def add_name(self, name_attr, space_attr, new_schema):
- """
- Add a new schema object to the name set.
-
- @arg name_attr: name value read in schema
- @arg space_attr: namespace value read in schema.
+ """Track name set and default namespace during parsing."""
+
+ def __init__(self, default_namespace=None):
+ self.names = {}
+ self.default_namespace = default_namespace
+
+ def has_name(self, name_attr, space_attr):
+ test = Name(name_attr, space_attr, self.default_namespace).fullname
+ return test in self.names
+
+ def get_name(self, name_attr, space_attr):
+ test = Name(name_attr, space_attr, self.default_namespace).fullname
+ if test not in self.names:
+ return None
+ return self.names[test]
+
+ def prune_namespace(self, properties):
+ """Given a properties dict, return a copy with the namespace removed
+ if it matches our own default namespace."""
+ if self.default_namespace is None:
+ # I have no default -- no change
+ return properties
+ if 'namespace' not in properties:
+ # he has no namespace - no change
+ return properties
+ if properties['namespace'] != self.default_namespace:
+ # we're different - leave his stuff alone
+ return properties
+ # we each have a namespace and it's redundant. delete his.
+ prunable = properties.copy()
+ del(prunable['namespace'])
+ return prunable
+
+ def add_name(self, name_attr, space_attr, new_schema):
+ """
+ Add a new schema object to the name set.
+
+ @arg name_attr: name value read in schema
+ @arg space_attr: namespace value read in schema.
+
+ @return: the Name that was just added.
+ """
+ to_add = Name(name_attr, space_attr, self.default_namespace)
+
+ if to_add.fullname in VALID_TYPES:
+ fail_msg = '%s is a reserved type name.' % to_add.fullname
+ raise SchemaParseException(fail_msg)
+ elif to_add.fullname in self.names:
+ fail_msg = 'The name "%s" is already in use.' % to_add.fullname
+ raise SchemaParseException(fail_msg)
+
+ self.names[to_add.fullname] = new_schema
+ return to_add
- @return: the Name that was just added.
- """
- to_add = Name(name_attr, space_attr, self.default_namespace)
-
- if to_add.fullname in VALID_TYPES:
- fail_msg = '%s is a reserved type name.' % to_add.fullname
- raise SchemaParseException(fail_msg)
- elif to_add.fullname in self.names:
- fail_msg = 'The name "%s" is already in use.' % to_add.fullname
- raise SchemaParseException(fail_msg)
-
- self.names[to_add.fullname] = new_schema
- return to_add
class NamedSchema(Schema):
- """Named Schemas specified in NAMED_TYPES."""
- def __init__(self, type, name, namespace=None, names=None, other_props=None):
- # Ensure valid ctor args
- if not name:
- fail_msg = 'Named Schemas must have a non-empty name.'
- raise SchemaParseException(fail_msg)
- elif not isinstance(name, basestring):
- fail_msg = 'The name property must be a string.'
- raise SchemaParseException(fail_msg)
- elif namespace is not None and not isinstance(namespace, basestring):
- fail_msg = 'The namespace property must be a string.'
- raise SchemaParseException(fail_msg)
-
- # Call parent ctor
- Schema.__init__(self, type, other_props)
-
- # Add class members
- new_name = names.add_name(name, namespace, self)
-
- # Store name and namespace as they were read in origin schema
- self.set_prop('name', name)
- if namespace is not None:
- self.set_prop('namespace', new_name.space)
-
- # Store full name as calculated from name, namespace
- self._fullname = new_name.fullname
-
- def name_ref(self, names):
- return self.name if self.namespace == names.default_namespace else self.fullname
-
- # read-only properties
- name = property(lambda self: self.get_prop('name'))
- namespace = property(lambda self: self.get_prop('namespace'))
- fullname = property(lambda self: self._fullname)
+ """Named Schemas specified in NAMED_TYPES."""
+
+ def __init__(self, type, name, namespace=None, names=None, other_props=None):
+ # Ensure valid ctor args
+ if not name:
+ fail_msg = 'Named Schemas must have a non-empty name.'
+ raise SchemaParseException(fail_msg)
+ elif not isinstance(name, basestring):
+ fail_msg = 'The name property must be a string.'
+ raise SchemaParseException(fail_msg)
+ elif namespace is not None and not isinstance(namespace, basestring):
+ fail_msg = 'The namespace property must be a string.'
+ raise SchemaParseException(fail_msg)
+
+ # Call parent ctor
+ Schema.__init__(self, type, other_props)
+
+ # Add class members
+ new_name = names.add_name(name, namespace, self)
+
+ # Store name and namespace as they were read in origin schema
+ self.set_prop('name', name)
+ if namespace is not None:
+ self.set_prop('namespace', new_name.space)
+
+ # Store full name as calculated from name, namespace
+ self._fullname = new_name.fullname
+
+ def name_ref(self, names):
+ return self.name if self.namespace == names.default_namespace else self.fullname
+
+ # read-only properties
+ name = property(lambda self: self.get_prop('name'))
+ namespace = property(lambda self: self.get_prop('namespace'))
+ fullname = property(lambda self: self._fullname)
#
# Logical type class
#
+
class LogicalSchema(object):
- def __init__(self, logical_type):
- self.logical_type = logical_type
+ def __init__(self, logical_type):
+ self.logical_type = logical_type
#
# Decimal logical schema
#
+
class DecimalLogicalSchema(LogicalSchema):
- def __init__(self, precision, scale=0, max_precision=0):
- if not isinstance(precision, int) or precision <= 0:
- raise IgnoredLogicalType(
- "Invalid decimal precision {}. Must be a positive integer.".format(precision))
+ def __init__(self, precision, scale=0, max_precision=0):
+ if not isinstance(precision, int) or precision <= 0:
+ raise IgnoredLogicalType(
+ "Invalid decimal precision {}. Must be a positive integer.".format(precision))
- if precision > max_precision:
- raise IgnoredLogicalType(
- "Invalid decimal precision {}. Max is {}.".format(precision, max_precision))
+ if precision > max_precision:
+ raise IgnoredLogicalType(
+ "Invalid decimal precision {}. Max is {}.".format(precision, max_precision))
- if not isinstance(scale, int) or scale < 0:
- raise IgnoredLogicalType(
- "Invalid decimal scale {}. Must be a positive integer.".format(scale))
+ if not isinstance(scale, int) or scale < 0:
+ raise IgnoredLogicalType(
+ "Invalid decimal scale {}. Must be a positive integer.".format(scale))
- if scale > precision:
- raise IgnoredLogicalType("Invalid decimal scale {}. Cannot be greater than precision {}."
- .format(scale, precision))
+ if scale > precision:
+ raise IgnoredLogicalType("Invalid decimal scale {}. Cannot be greater than precision {}."
+ .format(scale, precision))
- super(DecimalLogicalSchema, self).__init__('decimal')
+ super(DecimalLogicalSchema, self).__init__('decimal')
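
The checks above reduce to: precision must be a positive integer no greater than the type's maximum, and scale must be a non-negative integer no greater than precision. A small sketch using the bytes-backed decimal schema defined a bit further down (IgnoredLogicalType is raised here because the schema is constructed directly):

    import avro.schema

    ok = avro.schema.BytesDecimalSchema(precision=4, scale=2)
    assert (ok.precision, ok.scale) == (4, 2)

    # A scale larger than the precision is rejected.
    try:
        avro.schema.BytesDecimalSchema(precision=4, scale=5)
    except avro.schema.IgnoredLogicalType:
        pass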
class Field(object):
- def __init__(self, type, name, has_default, default=None,
- order=None,names=None, doc=None, other_props=None):
- # Ensure valid ctor args
- if not name:
- fail_msg = 'Fields must have a non-empty name.'
- raise SchemaParseException(fail_msg)
- elif not isinstance(name, basestring):
- fail_msg = 'The name property must be a string.'
- raise SchemaParseException(fail_msg)
- elif order is not None and order not in VALID_FIELD_SORT_ORDERS:
- fail_msg = 'The order property %s is not valid.' % order
- raise SchemaParseException(fail_msg)
-
- # add members
- self._props = {}
- self._has_default = has_default
- self._props.update(other_props or {})
-
- if (isinstance(type, basestring) and names is not None
- and names.has_name(type, None)):
- type_schema = names.get_name(type, None)
- else:
- try:
- type_schema = make_avsc_object(type, names)
- except Exception as e:
- fail_msg = 'Type property "%s" not a valid Avro schema: %s' % (type, e)
- raise SchemaParseException(fail_msg)
- self.set_prop('type', type_schema)
- self.set_prop('name', name)
- self.type = type_schema
- self.name = name
- # TODO(hammer): check to ensure default is valid
- if has_default: self.set_prop('default', default)
- if order is not None: self.set_prop('order', order)
- if doc is not None: self.set_prop('doc', doc)
-
- # read-only properties
- default = property(lambda self: self.get_prop('default'))
- has_default = property(lambda self: self._has_default)
- order = property(lambda self: self.get_prop('order'))
- doc = property(lambda self: self.get_prop('doc'))
- props = property(lambda self: self._props)
-
- # Read-only property dict. Non-reserved properties
- other_props = property(lambda self: get_other_props(self._props, FIELD_RESERVED_PROPS),
- doc="dictionary of non-reserved properties")
+ def __init__(self, type, name, has_default, default=None,
+ order=None, names=None, doc=None, other_props=None):
+ # Ensure valid ctor args
+ if not name:
+ fail_msg = 'Fields must have a non-empty name.'
+ raise SchemaParseException(fail_msg)
+ elif not isinstance(name, basestring):
+ fail_msg = 'The name property must be a string.'
+ raise SchemaParseException(fail_msg)
+ elif order is not None and order not in VALID_FIELD_SORT_ORDERS:
+ fail_msg = 'The order property %s is not valid.' % order
+ raise SchemaParseException(fail_msg)
+
+ # add members
+ self._props = {}
+ self._has_default = has_default
+ self._props.update(other_props or {})
+
+ if (isinstance(type, basestring) and names is not None and
+ names.has_name(type, None)):
+ type_schema = names.get_name(type, None)
+ else:
+ try:
+ type_schema = make_avsc_object(type, names)
+ except Exception as e:
+ fail_msg = 'Type property "%s" not a valid Avro schema: %s' % (type, e)
+ raise SchemaParseException(fail_msg)
+ self.set_prop('type', type_schema)
+ self.set_prop('name', name)
+ self.type = type_schema
+ self.name = name
+ # TODO(hammer): check to ensure default is valid
+ if has_default:
+ self.set_prop('default', default)
+ if order is not None:
+ self.set_prop('order', order)
+ if doc is not None:
+ self.set_prop('doc', doc)
+
+ # read-only properties
+ default = property(lambda self: self.get_prop('default'))
+ has_default = property(lambda self: self._has_default)
+ order = property(lambda self: self.get_prop('order'))
+ doc = property(lambda self: self.get_prop('doc'))
+ props = property(lambda self: self._props)
+
+ # Read-only property dict. Non-reserved properties
+ other_props = property(lambda self: get_other_props(self._props, FIELD_RESERVED_PROPS),
+ doc="dictionary of non-reserved properties")
# utility functions to manipulate properties dict
- def get_prop(self, key):
- return self._props.get(key)
- def set_prop(self, key, value):
- self._props[key] = value
+ def get_prop(self, key):
+ return self._props.get(key)
- def __str__(self):
- return json.dumps(self.to_json())
+ def set_prop(self, key, value):
+ self._props[key] = value
- def to_json(self, names=None):
- if names is None:
- names = Names()
- to_dump = self.props.copy()
- to_dump['type'] = self.type.to_json(names)
- return to_dump
+ def __str__(self):
+ return json.dumps(self.to_json())
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = self.props.copy()
+ to_dump['type'] = self.type.to_json(names)
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
#
# Primitive Types
#
+
+
class PrimitiveSchema(Schema):
- """Valid primitive types are in PRIMITIVE_TYPES."""
- def __init__(self, type, other_props=None):
- # Ensure valid ctor args
- if type not in PRIMITIVE_TYPES:
- raise AvroException("%s is not a valid primitive type." % type)
+ """Valid primitive types are in PRIMITIVE_TYPES."""
- # Call parent ctor
- Schema.__init__(self, type, other_props=other_props)
+ def __init__(self, type, other_props=None):
+ # Ensure valid ctor args
+ if type not in PRIMITIVE_TYPES:
+ raise AvroException("%s is not a valid primitive type." % type)
- self.fullname = type
+ # Call parent ctor
+ Schema.__init__(self, type, other_props=other_props)
- def match(self, writer):
- """Return True if the current schema (as reader) matches the writer schema.
+ self.fullname = type
- @arg writer: the schema to match against
- @return bool
- """
- return self.type == writer.type or {
- 'float': self.type == 'double',
- 'int': self.type in {'double', 'float', 'long'},
- 'long': self.type in {'double', 'float',},
- }.get(writer.type, False)
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return self.type == writer.type or {
+ 'float': self.type == 'double',
+ 'int': self.type in {'double', 'float', 'long'},
+ 'long': self.type in {'double', 'float', },
+ }.get(writer.type, False)
- def to_json(self, names=None):
- if len(self.props) == 1:
- return self.fullname
- else:
- return self.props
+ def to_json(self, names=None):
+ if len(self.props) == 1:
+ return self.fullname
+ else:
+ return self.props
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
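
The promotion table inside match() above says a writer's int may be read as long, float, or double; a long as float or double; and a float as double, but never the other way around. A quick sketch:

    import avro.schema

    int_schema = avro.schema.PrimitiveSchema('int')
    long_schema = avro.schema.PrimitiveSchema('long')
    double_schema = avro.schema.PrimitiveSchema('double')

    assert long_schema.match(long_schema)      # identical types always match
    assert long_schema.match(int_schema)       # writer int promotes to reader long
    assert double_schema.match(long_schema)    # writer long promotes to reader double
    assert not int_schema.match(long_schema)   # promotion is one-way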
#
# Decimal Bytes Type
#
+
+
class BytesDecimalSchema(PrimitiveSchema, DecimalLogicalSchema):
- def __init__(self, precision, scale=0, other_props=None):
- DecimalLogicalSchema.__init__(self, precision, scale, max_precision=((1 << 31) - 1))
- PrimitiveSchema.__init__(self, 'bytes', other_props)
- self.set_prop('precision', precision)
- self.set_prop('scale', scale)
+ def __init__(self, precision, scale=0, other_props=None):
+ DecimalLogicalSchema.__init__(self, precision, scale, max_precision=((1 << 31) - 1))
+ PrimitiveSchema.__init__(self, 'bytes', other_props)
+ self.set_prop('precision', precision)
+ self.set_prop('scale', scale)
- # read-only properties
- precision = property(lambda self: self.get_prop('precision'))
- scale = property(lambda self: self.get_prop('scale'))
+ # read-only properties
+ precision = property(lambda self: self.get_prop('precision'))
+ scale = property(lambda self: self.get_prop('scale'))
- def to_json(self, names=None):
- return self.props
+ def to_json(self, names=None):
+ return self.props
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
#
# Complex Types (non-recursive)
#
class FixedSchema(NamedSchema):
- def __init__(self, name, namespace, size, names=None, other_props=None):
- # Ensure valid ctor args
- if not isinstance(size, int) or size < 0:
- fail_msg = 'Fixed Schema requires a valid positive integer for size property.'
- raise AvroException(fail_msg)
+ def __init__(self, name, namespace, size, names=None, other_props=None):
+ # Ensure valid ctor args
+ if not isinstance(size, int) or size < 0:
+ fail_msg = 'Fixed Schema requires a valid positive integer for size property.'
+ raise AvroException(fail_msg)
- # Call parent ctor
- NamedSchema.__init__(self, 'fixed', name, namespace, names, other_props)
+ # Call parent ctor
+ NamedSchema.__init__(self, 'fixed', name, namespace, names, other_props)
- # Add class members
- self.set_prop('size', size)
+ # Add class members
+ self.set_prop('size', size)
- # read-only properties
- size = property(lambda self: self.get_prop('size'))
+ # read-only properties
+ size = property(lambda self: self.get_prop('size'))
- def match(self, writer):
- """Return True if the current schema (as reader) matches the writer schema.
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
- @arg writer: the schema to match against
- @return bool
- """
- return self.type == writer.type and self.check_props(writer, ['fullname', 'size'])
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return self.type == writer.type and self.check_props(writer, ['fullname', 'size'])
- def to_json(self, names=None):
- if names is None:
- names = Names()
- if self.fullname in names.names:
- return self.name_ref(names)
- else:
- names.names[self.fullname] = self
- return names.prune_namespace(self.props)
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ if self.fullname in names.names:
+ return self.name_ref(names)
+ else:
+ names.names[self.fullname] = self
+ return names.prune_namespace(self.props)
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
#
# Decimal Fixed Type
#
+
class FixedDecimalSchema(FixedSchema, DecimalLogicalSchema):
- def __init__(self, size, name, precision, scale=0, namespace=None, names=None, other_props=None):
- max_precision = int(math.floor(math.log10(2) * (8 * size - 1)))
- DecimalLogicalSchema.__init__(self, precision, scale, max_precision)
- FixedSchema.__init__(self, name, namespace, size, names, other_props)
- self.set_prop('precision', precision)
- self.set_prop('scale', scale)
+ def __init__(self, size, name, precision, scale=0, namespace=None, names=None, other_props=None):
+ max_precision = int(math.floor(math.log10(2) * (8 * size - 1)))
+ DecimalLogicalSchema.__init__(self, precision, scale, max_precision)
+ FixedSchema.__init__(self, name, namespace, size, names, other_props)
+ self.set_prop('precision', precision)
+ self.set_prop('scale', scale)
- # read-only properties
- precision = property(lambda self: self.get_prop('precision'))
- scale = property(lambda self: self.get_prop('scale'))
+ # read-only properties
+ precision = property(lambda self: self.get_prop('precision'))
+ scale = property(lambda self: self.get_prop('scale'))
- def to_json(self, names=None):
- return self.props
+ def to_json(self, names=None):
+ return self.props
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
class EnumSchema(NamedSchema):
- def __init__(self, name, namespace, symbols, names=None, doc=None, other_props=None, validate_enum_symbols=True):
- """
- @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names.
- """
- if validate_enum_symbols:
- for symbol in symbols:
- try:
- validate_basename(symbol)
- except InvalidName:
- raise InvalidName("An enum symbol must be a valid schema name.")
-
- if len(set(symbols)) < len(symbols):
- fail_msg = 'Duplicate symbol: %s' % symbols
- raise AvroException(fail_msg)
-
- # Call parent ctor
- NamedSchema.__init__(self, 'enum', name, namespace, names, other_props)
-
- # Add class members
- self.set_prop('symbols', symbols)
- if doc is not None:
- self.set_prop('doc', doc)
-
- # read-only properties
- symbols = property(lambda self: self.get_prop('symbols'))
- doc = property(lambda self: self.get_prop('doc'))
-
- def match(self, writer):
- """Return True if the current schema (as reader) matches the writer schema.
-
- @arg writer: the schema to match against
- @return bool
- """
- return self.type == writer.type and self.check_props(writer, ['fullname'])
-
- def to_json(self, names=None):
- if names is None:
- names = Names()
- if self.fullname in names.names:
- return self.name_ref(names)
- else:
- names.names[self.fullname] = self
- return names.prune_namespace(self.props)
-
- def __eq__(self, that):
- return self.props == that.props
+ def __init__(self, name, namespace, symbols, names=None, doc=None, other_props=None, validate_enum_symbols=True):
+ """
+ @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names.
+ """
+ if validate_enum_symbols:
+ for symbol in symbols:
+ try:
+ validate_basename(symbol)
+ except InvalidName:
+ raise InvalidName("An enum symbol must be a valid schema name.")
+
+ if len(set(symbols)) < len(symbols):
+ fail_msg = 'Duplicate symbol: %s' % symbols
+ raise AvroException(fail_msg)
+
+ # Call parent ctor
+ NamedSchema.__init__(self, 'enum', name, namespace, names, other_props)
+
+ # Add class members
+ self.set_prop('symbols', symbols)
+ if doc is not None:
+ self.set_prop('doc', doc)
+
+ # read-only properties
+ symbols = property(lambda self: self.get_prop('symbols'))
+ doc = property(lambda self: self.get_prop('doc'))
+
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return self.type == writer.type and self.check_props(writer, ['fullname'])
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ if self.fullname in names.names:
+ return self.name_ref(names)
+ else:
+ names.names[self.fullname] = self
+ return names.prune_namespace(self.props)
+
+ def __eq__(self, that):
+ return self.props == that.props
#
# Complex Types (recursive)
#
-class ArraySchema(Schema):
- def __init__(self, items, names=None, other_props=None):
- # Call parent ctor
- Schema.__init__(self, 'array', other_props)
- # Add class members
-
- if isinstance(items, basestring) and names.has_name(items, None):
- items_schema = names.get_name(items, None)
- else:
- try:
- items_schema = make_avsc_object(items, names)
- except SchemaParseException as e:
- fail_msg = 'Items schema (%s) not a valid Avro schema: %s (known names: %s)' % (items, e, names.names.keys())
- raise SchemaParseException(fail_msg)
-
- self.set_prop('items', items_schema)
-
- # read-only properties
- items = property(lambda self: self.get_prop('items'))
- def match(self, writer):
- """Return True if the current schema (as reader) matches the writer schema.
-
- @arg writer: the schema to match against
- @return bool
- """
- return self.type == writer.type and self.items.check_props(writer.items, ['type'])
-
- def to_json(self, names=None):
- if names is None:
- names = Names()
- to_dump = self.props.copy()
- item_schema = self.get_prop('items')
- to_dump['items'] = item_schema.to_json(names)
- return to_dump
+class ArraySchema(Schema):
+ def __init__(self, items, names=None, other_props=None):
+ # Call parent ctor
+ Schema.__init__(self, 'array', other_props)
+ # Add class members
+
+ if isinstance(items, basestring) and names.has_name(items, None):
+ items_schema = names.get_name(items, None)
+ else:
+ try:
+ items_schema = make_avsc_object(items, names)
+ except SchemaParseException as e:
+ fail_msg = 'Items schema (%s) not a valid Avro schema: %s (known names: %s)' % (items, e, names.names.keys())
+ raise SchemaParseException(fail_msg)
+
+ self.set_prop('items', items_schema)
+
+ # read-only properties
+ items = property(lambda self: self.get_prop('items'))
+
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return self.type == writer.type and self.items.check_props(writer.items, ['type'])
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = self.props.copy()
+ item_schema = self.get_prop('items')
+ to_dump['items'] = item_schema.to_json(names)
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
class MapSchema(Schema):
- def __init__(self, values, names=None, other_props=None):
- # Call parent ctor
- Schema.__init__(self, 'map',other_props)
-
- # Add class members
- if isinstance(values, basestring) and names.has_name(values, None):
- values_schema = names.get_name(values, None)
- else:
- try:
- values_schema = make_avsc_object(values, names)
- except SchemaParseException:
- raise
- except Exception:
- raise SchemaParseException('Values schema is not a valid Avro schema.')
+ def __init__(self, values, names=None, other_props=None):
+ # Call parent ctor
+ Schema.__init__(self, 'map', other_props)
+
+ # Add class members
+ if isinstance(values, basestring) and names.has_name(values, None):
+ values_schema = names.get_name(values, None)
+ else:
+ try:
+ values_schema = make_avsc_object(values, names)
+ except SchemaParseException:
+ raise
+ except Exception:
+ raise SchemaParseException('Values schema is not a valid Avro schema.')
+
+ self.set_prop('values', values_schema)
+
+ # read-only properties
+ values = property(lambda self: self.get_prop('values'))
+
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return writer.type == self.type and self.values.check_props(writer.values, ['type'])
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = self.props.copy()
+ to_dump['values'] = self.get_prop('values').to_json(names)
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
- self.set_prop('values', values_schema)
-
- # read-only properties
- values = property(lambda self: self.get_prop('values'))
-
- def match(self, writer):
- """Return True if the current schema (as reader) matches the writer schema.
-
- @arg writer: the schema to match against
- @return bool
- """
- return writer.type == self.type and self.values.check_props(writer.values, ['type'])
-
- def to_json(self, names=None):
- if names is None:
- names = Names()
- to_dump = self.props.copy()
- to_dump['values'] = self.get_prop('values').to_json(names)
- return to_dump
-
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
class UnionSchema(Schema):
- """
- names is a dictionary of schema objects
- """
- def __init__(self, schemas, names=None):
- # Ensure valid ctor args
- if not isinstance(schemas, list):
- fail_msg = 'Union schema requires a list of schemas.'
- raise SchemaParseException(fail_msg)
-
- # Call parent ctor
- Schema.__init__(self, 'union')
-
- # Add class members
- schema_objects = []
- for schema in schemas:
- if isinstance(schema, basestring) and names.has_name(schema, None):
- new_schema = names.get_name(schema, None)
- else:
- try:
- new_schema = make_avsc_object(schema, names)
- except Exception as e:
- raise SchemaParseException('Union item must be a valid Avro schema: %s' % str(e))
- # check the new schema
- if (new_schema.type in VALID_TYPES and new_schema.type not in NAMED_TYPES
- and new_schema.type in [schema.type for schema in schema_objects]):
- raise SchemaParseException('%s type already in Union' % new_schema.type)
- elif new_schema.type == 'union':
- raise SchemaParseException('Unions cannot contain other unions.')
- else:
- schema_objects.append(new_schema)
- self._schemas = schema_objects
-
- # read-only properties
- schemas = property(lambda self: self._schemas)
-
- def match(self, writer):
- """Return True if the current schema (as reader) matches the writer schema.
-
- @arg writer: the schema to match against
- @return bool
"""
- return writer.type in {'union', 'error_union'} or any(s.match(writer) for s in self.schemas)
+ names is a dictionary of schema objects
+ """
- def to_json(self, names=None):
- if names is None:
- names = Names()
- to_dump = []
- for schema in self.schemas:
- to_dump.append(schema.to_json(names))
- return to_dump
+ def __init__(self, schemas, names=None):
+ # Ensure valid ctor args
+ if not isinstance(schemas, list):
+ fail_msg = 'Union schema requires a list of schemas.'
+ raise SchemaParseException(fail_msg)
+
+ # Call parent ctor
+ Schema.__init__(self, 'union')
+
+ # Add class members
+ schema_objects = []
+ for schema in schemas:
+ if isinstance(schema, basestring) and names.has_name(schema, None):
+ new_schema = names.get_name(schema, None)
+ else:
+ try:
+ new_schema = make_avsc_object(schema, names)
+ except Exception as e:
+ raise SchemaParseException('Union item must be a valid Avro schema: %s' % str(e))
+ # check the new schema
+ if (new_schema.type in VALID_TYPES and new_schema.type not in NAMED_TYPES and
+ new_schema.type in [schema.type for schema in schema_objects]):
+ raise SchemaParseException('%s type already in Union' % new_schema.type)
+ elif new_schema.type == 'union':
+ raise SchemaParseException('Unions cannot contain other unions.')
+ else:
+ schema_objects.append(new_schema)
+ self._schemas = schema_objects
+
+ # read-only properties
+ schemas = property(lambda self: self._schemas)
+
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return writer.type in {'union', 'error_union'} or any(s.match(writer) for s in self.schemas)
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = []
+ for schema in self.schemas:
+ to_dump.append(schema.to_json(names))
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
class ErrorUnionSchema(UnionSchema):
- def __init__(self, schemas, names=None):
- # Prepend "string" to handle system errors
- UnionSchema.__init__(self, ['string'] + schemas, names)
+ def __init__(self, schemas, names=None):
+ # Prepend "string" to handle system errors
+ UnionSchema.__init__(self, ['string'] + schemas, names)
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = []
+ for schema in self.schemas:
+ # Don't print the system error schema
+ if schema.type == 'string':
+ continue
+ to_dump.append(schema.to_json(names))
+ return to_dump
- def to_json(self, names=None):
- if names is None:
- names = Names()
- to_dump = []
- for schema in self.schemas:
- # Don't print the system error schema
- if schema.type == 'string': continue
- to_dump.append(schema.to_json(names))
- return to_dump
class RecordSchema(NamedSchema):
- @staticmethod
- def make_field_objects(field_data, names):
- """We're going to need to make message parameters too."""
- field_objects = []
- field_names = []
- for i, field in enumerate(field_data):
- if callable(getattr(field, 'get', None)):
- type = field.get('type')
- name = field.get('name')
-
- # null values can have a default value of None
- has_default = False
- default = None
- if 'default' in field:
- has_default = True
- default = field.get('default')
-
- order = field.get('order')
- doc = field.get('doc')
- other_props = get_other_props(field, FIELD_RESERVED_PROPS)
- new_field = Field(type, name, has_default, default, order, names, doc,
- other_props)
- # make sure field name has not been used yet
- if new_field.name in field_names:
- fail_msg = 'Field name %s already in use.' % new_field.name
- raise SchemaParseException(fail_msg)
- field_names.append(new_field.name)
- else:
- raise SchemaParseException('Not a valid field: %s' % field)
- field_objects.append(new_field)
- return field_objects
-
- def match(self, writer):
- """Return True if the current schema (as reader) matches the other schema.
-
- @arg writer: the schema to match against
- @return bool
- """
- return writer.type == self.type and (self.type == 'request' or self.check_props(writer, ['fullname']))
-
- def __init__(self, name, namespace, fields, names=None, schema_type='record',
- doc=None, other_props=None):
- # Ensure valid ctor args
- if fields is None:
- fail_msg = 'Record schema requires a non-empty fields property.'
- raise SchemaParseException(fail_msg)
- elif not isinstance(fields, list):
- fail_msg = 'Fields property must be a list of Avro schemas.'
- raise SchemaParseException(fail_msg)
-
- # Call parent ctor (adds own name to namespace, too)
- if schema_type == 'request':
- Schema.__init__(self, schema_type, other_props)
- else:
- NamedSchema.__init__(self, schema_type, name, namespace, names,
- other_props)
-
- if schema_type == 'record':
- old_default = names.default_namespace
- names.default_namespace = Name(name, namespace,
- names.default_namespace).space
-
- # Add class members
- field_objects = RecordSchema.make_field_objects(fields, names)
- self.set_prop('fields', field_objects)
- if doc is not None: self.set_prop('doc', doc)
-
- if schema_type == 'record':
- names.default_namespace = old_default
-
- # read-only properties
- fields = property(lambda self: self.get_prop('fields'))
- doc = property(lambda self: self.get_prop('doc'))
-
- @property
- def fields_dict(self):
- fields_dict = {}
- for field in self.fields:
- fields_dict[field.name] = field
- return fields_dict
-
- def to_json(self, names=None):
- if names is None:
- names = Names()
- # Request records don't have names
- if self.type == 'request':
- return [ f.to_json(names) for f in self.fields ]
-
- if self.fullname in names.names:
- return self.name_ref(names)
- else:
- names.names[self.fullname] = self
-
- to_dump = names.prune_namespace(self.props.copy())
- to_dump['fields'] = [ f.to_json(names) for f in self.fields ]
- return to_dump
-
- def __eq__(self, that):
- to_cmp = json.loads(str(self))
- return to_cmp == json.loads(str(that))
+ @staticmethod
+ def make_field_objects(field_data, names):
+ """We're going to need to make message parameters too."""
+ field_objects = []
+ field_names = []
+ for i, field in enumerate(field_data):
+ if callable(getattr(field, 'get', None)):
+ type = field.get('type')
+ name = field.get('name')
+
+ # null values can have a default value of None
+ has_default = False
+ default = None
+ if 'default' in field:
+ has_default = True
+ default = field.get('default')
+
+ order = field.get('order')
+ doc = field.get('doc')
+ other_props = get_other_props(field, FIELD_RESERVED_PROPS)
+ new_field = Field(type, name, has_default, default, order, names, doc,
+ other_props)
+ # make sure field name has not been used yet
+ if new_field.name in field_names:
+ fail_msg = 'Field name %s already in use.' % new_field.name
+ raise SchemaParseException(fail_msg)
+ field_names.append(new_field.name)
+ else:
+ raise SchemaParseException('Not a valid field: %s' % field)
+ field_objects.append(new_field)
+ return field_objects
+
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the other schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return writer.type == self.type and (self.type == 'request' or self.check_props(writer, ['fullname']))
+
+ def __init__(self, name, namespace, fields, names=None, schema_type='record',
+ doc=None, other_props=None):
+ # Ensure valid ctor args
+ if fields is None:
+ fail_msg = 'Record schema requires a non-empty fields property.'
+ raise SchemaParseException(fail_msg)
+ elif not isinstance(fields, list):
+ fail_msg = 'Fields property must be a list of Avro schemas.'
+ raise SchemaParseException(fail_msg)
+
+ # Call parent ctor (adds own name to namespace, too)
+ if schema_type == 'request':
+ Schema.__init__(self, schema_type, other_props)
+ else:
+ NamedSchema.__init__(self, schema_type, name, namespace, names,
+ other_props)
+
+ if schema_type == 'record':
+ old_default = names.default_namespace
+ names.default_namespace = Name(name, namespace,
+ names.default_namespace).space
+
+ # Add class members
+ field_objects = RecordSchema.make_field_objects(fields, names)
+ self.set_prop('fields', field_objects)
+ if doc is not None:
+ self.set_prop('doc', doc)
+
+ if schema_type == 'record':
+ names.default_namespace = old_default
+
+ # read-only properties
+ fields = property(lambda self: self.get_prop('fields'))
+ doc = property(lambda self: self.get_prop('doc'))
+
+ @property
+ def fields_dict(self):
+ fields_dict = {}
+ for field in self.fields:
+ fields_dict[field.name] = field
+ return fields_dict
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ # Request records don't have names
+ if self.type == 'request':
+ return [f.to_json(names) for f in self.fields]
+
+ if self.fullname in names.names:
+ return self.name_ref(names)
+ else:
+ names.names[self.fullname] = self
+
+ to_dump = names.prune_namespace(self.props.copy())
+ to_dump['fields'] = [f.to_json(names) for f in self.fields]
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
#
@@ -893,209 +929,220 @@ class RecordSchema(NamedSchema):
#
class DateSchema(LogicalSchema, PrimitiveSchema):
- def __init__(self, other_props=None):
- LogicalSchema.__init__(self, constants.DATE)
- PrimitiveSchema.__init__(self, 'int', other_props)
+ def __init__(self, other_props=None):
+ LogicalSchema.__init__(self, constants.DATE)
+ PrimitiveSchema.__init__(self, 'int', other_props)
- def to_json(self, names=None):
- return self.props
+ def to_json(self, names=None):
+ return self.props
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
#
# time-millis Type
#
+
class TimeMillisSchema(LogicalSchema, PrimitiveSchema):
- def __init__(self, other_props=None):
- LogicalSchema.__init__(self, constants.TIME_MILLIS)
- PrimitiveSchema.__init__(self, 'int', other_props)
+ def __init__(self, other_props=None):
+ LogicalSchema.__init__(self, constants.TIME_MILLIS)
+ PrimitiveSchema.__init__(self, 'int', other_props)
- def to_json(self, names=None):
- return self.props
+ def to_json(self, names=None):
+ return self.props
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
#
# time-micros Type
#
+
class TimeMicrosSchema(LogicalSchema, PrimitiveSchema):
- def __init__(self, other_props=None):
- LogicalSchema.__init__(self, constants.TIME_MICROS)
- PrimitiveSchema.__init__(self, 'long', other_props)
+ def __init__(self, other_props=None):
+ LogicalSchema.__init__(self, constants.TIME_MICROS)
+ PrimitiveSchema.__init__(self, 'long', other_props)
- def to_json(self, names=None):
- return self.props
+ def to_json(self, names=None):
+ return self.props
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
#
# timestamp-millis Type
#
+
class TimestampMillisSchema(LogicalSchema, PrimitiveSchema):
- def __init__(self, other_props=None):
- LogicalSchema.__init__(self, constants.TIMESTAMP_MILLIS)
- PrimitiveSchema.__init__(self, 'long', other_props)
+ def __init__(self, other_props=None):
+ LogicalSchema.__init__(self, constants.TIMESTAMP_MILLIS)
+ PrimitiveSchema.__init__(self, 'long', other_props)
- def to_json(self, names=None):
- return self.props
+ def to_json(self, names=None):
+ return self.props
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
#
# timestamp-micros Type
#
+
class TimestampMicrosSchema(LogicalSchema, PrimitiveSchema):
- def __init__(self, other_props=None):
- LogicalSchema.__init__(self, constants.TIMESTAMP_MICROS)
- PrimitiveSchema.__init__(self, 'long', other_props)
+ def __init__(self, other_props=None):
+ LogicalSchema.__init__(self, constants.TIMESTAMP_MICROS)
+ PrimitiveSchema.__init__(self, 'long', other_props)
- def to_json(self, names=None):
- return self.props
+ def to_json(self, names=None):
+ return self.props
- def __eq__(self, that):
- return self.props == that.props
+ def __eq__(self, that):
+ return self.props == that.props
#
# Module Methods
#
+
+
def get_other_props(all_props, reserved_props):
- """
- Retrieve the non-reserved properties from a dictionary of properties
- @args reserved_props: The set of reserved properties to exclude
- """
- if callable(getattr(all_props, 'items', None)):
- return {k: v for k, v in all_props.items() if k not in reserved_props}
+ """
+ Retrieve the non-reserved properties from a dictionary of properties
+ @args reserved_props: The set of reserved properties to exclude
+ """
+ if callable(getattr(all_props, 'items', None)):
+ return {k: v for k, v in all_props.items() if k not in reserved_props}
+
def make_bytes_decimal_schema(other_props):
- """Make a BytesDecimalSchema from just other_props."""
- return BytesDecimalSchema(other_props.get('precision'), other_props.get('scale', 0))
+ """Make a BytesDecimalSchema from just other_props."""
+ return BytesDecimalSchema(other_props.get('precision'), other_props.get('scale', 0))
+
def make_logical_schema(logical_type, type_, other_props):
- """Map the logical types to the appropriate literal type and schema class."""
- logical_types = {
- (constants.DATE, 'int'): DateSchema,
- (constants.DECIMAL, 'bytes'): make_bytes_decimal_schema,
- # The fixed decimal schema is handled later by returning None now.
- (constants.DECIMAL, 'fixed'): lambda x: None,
- (constants.TIMESTAMP_MICROS, 'long'): TimestampMicrosSchema,
- (constants.TIMESTAMP_MILLIS, 'long'): TimestampMillisSchema,
- (constants.TIME_MICROS, 'long'): TimeMicrosSchema,
- (constants.TIME_MILLIS, 'int'): TimeMillisSchema,
- }
- try:
- schema_type = logical_types.get((logical_type, type_), None)
- if schema_type is not None:
- return schema_type(other_props)
-
- expected_types = sorted(literal_type for lt, literal_type in logical_types if lt == logical_type)
- if expected_types:
- warnings.warn(
- IgnoredLogicalType("Logical type {} requires literal type {}, not {}.".format(
- logical_type, "/".join(expected_types), type_)))
- else:
- warnings.warn(IgnoredLogicalType("Unknown {}, using {}.".format(logical_type, type_)))
- except IgnoredLogicalType as warning:
- warnings.warn(warning)
- return None
+ """Map the logical types to the appropriate literal type and schema class."""
+ logical_types = {
+ (constants.DATE, 'int'): DateSchema,
+ (constants.DECIMAL, 'bytes'): make_bytes_decimal_schema,
+ # The fixed decimal schema is handled later by returning None now.
+ (constants.DECIMAL, 'fixed'): lambda x: None,
+ (constants.TIMESTAMP_MICROS, 'long'): TimestampMicrosSchema,
+ (constants.TIMESTAMP_MILLIS, 'long'): TimestampMillisSchema,
+ (constants.TIME_MICROS, 'long'): TimeMicrosSchema,
+ (constants.TIME_MILLIS, 'int'): TimeMillisSchema,
+ }
+ try:
+ schema_type = logical_types.get((logical_type, type_), None)
+ if schema_type is not None:
+ return schema_type(other_props)
+
+ expected_types = sorted(literal_type for lt, literal_type in logical_types if lt == logical_type)
+ if expected_types:
+ warnings.warn(
+ IgnoredLogicalType("Logical type {} requires literal type {}, not {}.".format(
+ logical_type, "/".join(expected_types), type_)))
+ else:
+ warnings.warn(IgnoredLogicalType("Unknown {}, using {}.".format(logical_type, type_)))
+ except IgnoredLogicalType as warning:
+ warnings.warn(warning)
+ return None
-def make_avsc_object(json_data, names=None, validate_enum_symbols=True):
- """
- Build Avro Schema from data parsed out of JSON string.
- @arg names: A Names object (tracks seen names and default space)
- @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names.
- """
- if names is None:
- names = Names()
+def make_avsc_object(json_data, names=None, validate_enum_symbols=True):
+ """
+ Build Avro Schema from data parsed out of JSON string.
- # JSON object (non-union)
- if callable(getattr(json_data, 'get', None)):
- type = json_data.get('type')
- other_props = get_other_props(json_data, SCHEMA_RESERVED_PROPS)
- logical_type = json_data.get('logicalType')
- if logical_type:
- logical_schema = make_logical_schema(logical_type, type, other_props or {})
- if logical_schema is not None:
- return logical_schema
- if type in NAMED_TYPES:
- name = json_data.get('name')
- namespace = json_data.get('namespace', names.default_namespace)
- if type == 'fixed':
- size = json_data.get('size')
- if logical_type == 'decimal':
- precision = json_data.get('precision')
- scale = 0 if json_data.get('scale') is None else json_data.get('scale')
- try:
- return FixedDecimalSchema(size, name, precision, scale, namespace, names, other_props)
- except IgnoredLogicalType as warning:
- warnings.warn(warning)
- return FixedSchema(name, namespace, size, names, other_props)
- elif type == 'enum':
- symbols = json_data.get('symbols')
- doc = json_data.get('doc')
- return EnumSchema(name, namespace, symbols, names, doc, other_props, validate_enum_symbols)
- elif type in ['record', 'error']:
- fields = json_data.get('fields')
- doc = json_data.get('doc')
- return RecordSchema(name, namespace, fields, names, type, doc, other_props)
- else:
- raise SchemaParseException('Unknown Named Type: %s' % type)
- if type in PRIMITIVE_TYPES:
- return PrimitiveSchema(type, other_props)
- if type in VALID_TYPES:
- if type == 'array':
- items = json_data.get('items')
- return ArraySchema(items, names, other_props)
- elif type == 'map':
- values = json_data.get('values')
- return MapSchema(values, names, other_props)
- elif type == 'error_union':
- declared_errors = json_data.get('declared_errors')
- return ErrorUnionSchema(declared_errors, names)
- else:
- raise SchemaParseException('Unknown Valid Type: %s' % type)
- elif type is None:
- raise SchemaParseException('No "type" property: %s' % json_data)
+ @arg names: A Names object (tracks seen names and default space)
+ @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names.
+ """
+ if names is None:
+ names = Names()
+
+ # JSON object (non-union)
+ if callable(getattr(json_data, 'get', None)):
+ type = json_data.get('type')
+ other_props = get_other_props(json_data, SCHEMA_RESERVED_PROPS)
+ logical_type = json_data.get('logicalType')
+ if logical_type:
+ logical_schema = make_logical_schema(logical_type, type, other_props or {})
+ if logical_schema is not None:
+ return logical_schema
+ if type in NAMED_TYPES:
+ name = json_data.get('name')
+ namespace = json_data.get('namespace', names.default_namespace)
+ if type == 'fixed':
+ size = json_data.get('size')
+ if logical_type == 'decimal':
+ precision = json_data.get('precision')
+ scale = 0 if json_data.get('scale') is None else json_data.get('scale')
+ try:
+ return FixedDecimalSchema(size, name, precision, scale, namespace, names, other_props)
+ except IgnoredLogicalType as warning:
+ warnings.warn(warning)
+ return FixedSchema(name, namespace, size, names, other_props)
+ elif type == 'enum':
+ symbols = json_data.get('symbols')
+ doc = json_data.get('doc')
+ return EnumSchema(name, namespace, symbols, names, doc, other_props, validate_enum_symbols)
+ elif type in ['record', 'error']:
+ fields = json_data.get('fields')
+ doc = json_data.get('doc')
+ return RecordSchema(name, namespace, fields, names, type, doc, other_props)
+ else:
+ raise SchemaParseException('Unknown Named Type: %s' % type)
+ if type in PRIMITIVE_TYPES:
+ return PrimitiveSchema(type, other_props)
+ if type in VALID_TYPES:
+ if type == 'array':
+ items = json_data.get('items')
+ return ArraySchema(items, names, other_props)
+ elif type == 'map':
+ values = json_data.get('values')
+ return MapSchema(values, names, other_props)
+ elif type == 'error_union':
+ declared_errors = json_data.get('declared_errors')
+ return ErrorUnionSchema(declared_errors, names)
+ else:
+ raise SchemaParseException('Unknown Valid Type: %s' % type)
+ elif type is None:
+ raise SchemaParseException('No "type" property: %s' % json_data)
+ else:
+ raise SchemaParseException('Undefined type: %s' % type)
+ # JSON array (union)
+ elif isinstance(json_data, list):
+ return UnionSchema(json_data, names)
+ # JSON string (primitive)
+ elif json_data in PRIMITIVE_TYPES:
+ return PrimitiveSchema(json_data)
+ # not for us!
else:
- raise SchemaParseException('Undefined type: %s' % type)
- # JSON array (union)
- elif isinstance(json_data, list):
- return UnionSchema(json_data, names)
- # JSON string (primitive)
- elif json_data in PRIMITIVE_TYPES:
- return PrimitiveSchema(json_data)
- # not for us!
- else:
- fail_msg = "Could not make an Avro Schema object from %s." % json_data
- raise SchemaParseException(fail_msg)
+ fail_msg = "Could not make an Avro Schema object from %s." % json_data
+ raise SchemaParseException(fail_msg)
# TODO(hammer): make method for reading from a file?
+
+
def parse(json_string, validate_enum_symbols=True):
- """Constructs the Schema from the JSON text.
-
- @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names.
- """
- # parse the JSON
- try:
- json_data = json.loads(json_string)
- except Exception as e:
- msg = 'Error parsing JSON: {}, error = {}'.format(json_string, e)
- new_exception = SchemaParseException(msg)
- traceback = sys.exc_info()[2]
- if not hasattr(new_exception, 'with_traceback'):
- raise (new_exception, None, traceback) # Python 2 syntax
- raise new_exception.with_traceback(traceback)
-
- # Initialize the names object
- names = Names()
-
- # construct the Avro Schema object
- return make_avsc_object(json_data, names, validate_enum_symbols)
+ """Constructs the Schema from the JSON text.
+
+ @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names.
+ """
+ # parse the JSON
+ try:
+ json_data = json.loads(json_string)
+ except Exception as e:
+ msg = 'Error parsing JSON: {}, error = {}'.format(json_string, e)
+ new_exception = SchemaParseException(msg)
+ traceback = sys.exc_info()[2]
+ if not hasattr(new_exception, 'with_traceback'):
+ raise (new_exception, None, traceback) # Python 2 syntax
+ raise new_exception.with_traceback(traceback)
+
+ # Initialize the names object
+ names = Names()
+
+ # construct the Avro Schema object
+ return make_avsc_object(json_data, names, validate_enum_symbols)
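(A minimal usage sketch of the avro.schema.parse() entry point reindented above; the schema literal and variable names below are illustrative only and are not part of this commit.)

    import avro.schema

    # parse() loads the JSON text and hands it to make_avsc_object(),
    # which returns the matching Schema subclass (a RecordSchema here).
    example_schema = avro.schema.parse("""
    {"type": "record",
     "name": "Query",
     "fields": [{"name": "query", "type": "string"},
                {"name": "response", "type": "string"}]}
    """)
    # str(schema) yields JSON text; json.loads(str(self)) is how the
    # __eq__ implementations above compare schemas.
    print(example_schema)
    print(example_schema.to_json())  # JSON-compatible representation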
diff --git a/lang/py/avro/test/av_bench.py b/lang/py/avro/test/av_bench.py
index 96c95d9..bc94ae4 100644
--- a/lang/py/avro/test/av_bench.py
+++ b/lang/py/avro/test/av_bench.py
@@ -30,14 +30,17 @@ import avro.schema
types = ["A", "CNAME"]
+
def rand_name():
return ''.join(sample(string.ascii_lowercase, 15))
+
def rand_ip():
- return "%s.%s.%s.%s" %(randint(0,255), randint(0,255), randint(0,255), randint(0,255))
+ return "%s.%s.%s.%s" % (randint(0, 255), randint(0, 255), randint(0, 255), randint(0, 255))
+
def write(n):
- schema_s="""
+ schema_s = """
{ "type": "record",
"name": "Query",
"fields" : [
@@ -45,11 +48,11 @@ def write(n):
{"name": "response", "type": "string"},
{"name": "type", "type": "string", "default": "A"}
]}"""
- out = open("datafile.avr",'w')
+ out = open("datafile.avr", 'w')
schema = avro.schema.parse(schema_s)
writer = avro.io.DatumWriter(schema)
- dw = avro.datafile.DataFileWriter(out, writer, schema) #,codec='deflate')
+ dw = avro.datafile.DataFileWriter(out, writer, schema) # ,codec='deflate')
for _ in xrange(n):
response = rand_ip()
query = rand_name()
@@ -58,20 +61,23 @@ def write(n):
dw.close()
+
def read():
f = open("datafile.avr")
reader = avro.io.DatumReader()
- af=avro.datafile.DataFileReader(f,reader)
+ af = avro.datafile.DataFileReader(f, reader)
- x=0
+ x = 0
for _ in af:
pass
+
def t(f, *args):
s = time.time()
f(*args)
e = time.time()
- return e-s
+ return e - s
+
if __name__ == "__main__":
n = int(sys.argv[1])
diff --git a/lang/py/avro/test/gen_interop_data.py b/lang/py/avro/test/gen_interop_data.py
index c4bc65b..1bd5c76 100644
--- a/lang/py/avro/test/gen_interop_data.py
+++ b/lang/py/avro/test/gen_interop_data.py
@@ -29,43 +29,45 @@ import avro.schema
from avro.codecs import Codecs
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
NULL_CODEC = 'null'
CODECS_TO_VALIDATE = Codecs.supported_codec_names()
DATUM = {
- 'intField': 12,
- 'longField': 15234324,
- 'stringField': unicode('hey'),
- 'boolField': True,
- 'floatField': 1234.0,
- 'doubleField': -1234.0,
- 'bytesField': b'12312adf',
- 'nullField': None,
- 'arrayField': [5.0, 0.0, 12.0],
- 'mapField': {unicode('a'): {'label': unicode('a')},
- unicode('bee'): {'label': unicode('cee')}},
- 'unionField': 12.0,
- 'enumField': 'C',
- 'fixedField': b'1019181716151413',
- 'recordField': {'label': unicode('blah'),
- 'children': [{'label': unicode('inner'), 'children': []}]},
+ 'intField': 12,
+ 'longField': 15234324,
+ 'stringField': unicode('hey'),
+ 'boolField': True,
+ 'floatField': 1234.0,
+ 'doubleField': -1234.0,
+ 'bytesField': b'12312adf',
+ 'nullField': None,
+ 'arrayField': [5.0, 0.0, 12.0],
+ 'mapField': {unicode('a'): {'label': unicode('a')},
+ unicode('bee'): {'label': unicode('cee')}},
+ 'unionField': 12.0,
+ 'enumField': 'C',
+ 'fixedField': b'1019181716151413',
+ 'recordField': {'label': unicode('blah'),
+ 'children': [{'label': unicode('inner'), 'children': []}]},
}
+
def generate(schema_path, output_path):
- with open(schema_path, 'r') as schema_file:
- interop_schema = avro.schema.parse(schema_file.read())
- for codec in CODECS_TO_VALIDATE:
- filename = output_path
- if codec != NULL_CODEC:
- base, ext = os.path.splitext(output_path)
- filename = base + "_" + codec + ext
- with avro.datafile.DataFileWriter(open(filename, 'wb'), avro.io.DatumWriter(),
- interop_schema, codec=codec) as dfw:
- dfw.append(DATUM)
+ with open(schema_path, 'r') as schema_file:
+ interop_schema = avro.schema.parse(schema_file.read())
+ for codec in CODECS_TO_VALIDATE:
+ filename = output_path
+ if codec != NULL_CODEC:
+ base, ext = os.path.splitext(output_path)
+ filename = base + "_" + codec + ext
+ with avro.datafile.DataFileWriter(open(filename, 'wb'), avro.io.DatumWriter(),
+ interop_schema, codec=codec) as dfw:
+ dfw.append(DATUM)
+
if __name__ == "__main__":
- generate(sys.argv[1], sys.argv[2])
+ generate(sys.argv[1], sys.argv[2])
diff --git a/lang/py/avro/test/mock_tether_parent.py b/lang/py/avro/test/mock_tether_parent.py
index dac958d..ac7a08d 100644
--- a/lang/py/avro/test/mock_tether_parent.py
+++ b/lang/py/avro/test/mock_tether_parent.py
@@ -27,66 +27,71 @@ import avro.tether.util
from avro import ipc, protocol
try:
- import BaseHTTPServer as http_server # type: ignore
+ import BaseHTTPServer as http_server # type: ignore
except ImportError:
- import http.server as http_server # type: ignore
+ import http.server as http_server # type: ignore
SERVER_ADDRESS = ('localhost', avro.tether.util.find_port())
+
class MockParentResponder(ipc.Responder):
- """
- The responder for the mocked parent
- """
- def __init__(self):
- ipc.Responder.__init__(self, avro.tether.tether_task.outputProtocol)
+ """
+ The responder for the mocked parent
+ """
+
+ def __init__(self):
+ ipc.Responder.__init__(self, avro.tether.tether_task.outputProtocol)
+
+ def invoke(self, message, request):
+ if message.name == 'configure':
+ print("MockParentResponder: Recieved 'configure': inputPort={0}".format(request["port"]))
- def invoke(self, message, request):
- if message.name=='configure':
- print("MockParentResponder: Recieved 'configure': inputPort={0}".format(request["port"]))
+ elif message.name == 'status':
+ print("MockParentResponder: Recieved 'status': message={0}".format(request["message"]))
+ elif message.name == 'fail':
+ print("MockParentResponder: Recieved 'fail': message={0}".format(request["message"]))
+ else:
+ print("MockParentResponder: Recieved {0}".format(message.name))
- elif message.name=='status':
- print("MockParentResponder: Recieved 'status': message={0}".format(request["message"]))
- elif message.name=='fail':
- print("MockParentResponder: Recieved 'fail': message={0}".format(request["message"]))
- else:
- print("MockParentResponder: Recieved {0}".format(message.name))
+ # flush the output so it shows up in the parent process
+ sys.stdout.flush()
- # flush the output so it shows up in the parent process
- sys.stdout.flush()
+ return None
- return None
class MockParentHandler(http_server.BaseHTTPRequestHandler):
- """Create a handler for the parent.
- """
- def do_POST(self):
- self.responder = MockParentResponder()
- call_request_reader = ipc.FramedReader(self.rfile)
- call_request = call_request_reader.read_framed_message()
- resp_body = self.responder.respond(call_request)
- self.send_response(200)
- self.send_header('Content-Type', 'avro/binary')
- self.end_headers()
- resp_writer = ipc.FramedWriter(self.wfile)
- resp_writer.write_framed_message(resp_body)
+ """Create a handler for the parent.
+ """
+
+ def do_POST(self):
+ self.responder = MockParentResponder()
+ call_request_reader = ipc.FramedReader(self.rfile)
+ call_request = call_request_reader.read_framed_message()
+ resp_body = self.responder.respond(call_request)
+ self.send_response(200)
+ self.send_header('Content-Type', 'avro/binary')
+ self.end_headers()
+ resp_writer = ipc.FramedWriter(self.wfile)
+ resp_writer.write_framed_message(resp_body)
+
if __name__ == '__main__':
- if (len(sys.argv)<=1):
- raise ValueError("Usage: mock_tether_parent command")
-
- cmd=sys.argv[1].lower()
- if (sys.argv[1]=='start_server'):
- if (len(sys.argv)==3):
- port=int(sys.argv[2])
- else:
- raise ValueError("Usage: mock_tether_parent start_server port")
-
- SERVER_ADDRESS=(SERVER_ADDRESS[0],port)
- print("mock_tether_parent: Launching Server on Port: {0}".format(SERVER_ADDRESS[1]))
-
- # flush the output so it shows up in the parent process
- sys.stdout.flush()
- parent_server = http_server.HTTPServer(SERVER_ADDRESS, MockParentHandler)
- parent_server.allow_reuse_address = True
- parent_server.serve_forever()
+ if (len(sys.argv) <= 1):
+ raise ValueError("Usage: mock_tether_parent command")
+
+ cmd = sys.argv[1].lower()
+ if (sys.argv[1] == 'start_server'):
+ if (len(sys.argv) == 3):
+ port = int(sys.argv[2])
+ else:
+ raise ValueError("Usage: mock_tether_parent start_server port")
+
+ SERVER_ADDRESS = (SERVER_ADDRESS[0], port)
+ print("mock_tether_parent: Launching Server on Port: {0}".format(SERVER_ADDRESS[1]))
+
+ # flush the output so it shows up in the parent process
+ sys.stdout.flush()
+ parent_server = http_server.HTTPServer(SERVER_ADDRESS, MockParentHandler)
+ parent_server.allow_reuse_address = True
+ parent_server.serve_forever()
diff --git a/lang/py/avro/test/sample_http_client.py b/lang/py/avro/test/sample_http_client.py
index 02b8421..4e9b881 100644
--- a/lang/py/avro/test/sample_http_client.py
+++ b/lang/py/avro/test/sample_http_client.py
@@ -54,43 +54,47 @@ MAIL_PROTOCOL = protocol.parse(MAIL_PROTOCOL_JSON)
SERVER_HOST = 'localhost'
SERVER_PORT = 9090
+
class UsageError(Exception):
- def __init__(self, value):
- self.value = value
- def __str__(self):
- return repr(self.value)
+ def __init__(self, value):
+ self.value = value
+
+ def __str__(self):
+ return repr(self.value)
+
def make_requestor(server_host, server_port, protocol):
- client = ipc.HTTPTransceiver(SERVER_HOST, SERVER_PORT)
- return ipc.Requestor(protocol, client)
+ client = ipc.HTTPTransceiver(SERVER_HOST, SERVER_PORT)
+ return ipc.Requestor(protocol, client)
+
if __name__ == '__main__':
- if len(sys.argv) not in [4, 5]:
- raise UsageError("Usage: <to> <from> <body> [<count>]")
-
- # client code - attach to the server and send a message
- # fill in the Message record
- message = dict()
- message['to'] = sys.argv[1]
- message['from'] = sys.argv[2]
- message['body'] = sys.argv[3]
-
- try:
- num_messages = int(sys.argv[4])
- except IndexError:
- num_messages = 1
-
- # build the parameters for the request
- params = {}
- params['message'] = message
-
- # send the requests and print the result
- for msg_count in range(num_messages):
- requestor = make_requestor(SERVER_HOST, SERVER_PORT, MAIL_PROTOCOL)
- result = requestor.request('send', params)
- print("Result: " + result)
+ if len(sys.argv) not in [4, 5]:
+ raise UsageError("Usage: <to> <from> <body> [<count>]")
- # try out a replay message
- requestor = make_requestor(SERVER_HOST, SERVER_PORT, MAIL_PROTOCOL)
- result = requestor.request('replay', dict())
- print("Replay Result: " + result)
+ # client code - attach to the server and send a message
+ # fill in the Message record
+ message = dict()
+ message['to'] = sys.argv[1]
+ message['from'] = sys.argv[2]
+ message['body'] = sys.argv[3]
+
+ try:
+ num_messages = int(sys.argv[4])
+ except IndexError:
+ num_messages = 1
+
+ # build the parameters for the request
+ params = {}
+ params['message'] = message
+
+ # send the requests and print the result
+ for msg_count in range(num_messages):
+ requestor = make_requestor(SERVER_HOST, SERVER_PORT, MAIL_PROTOCOL)
+ result = requestor.request('send', params)
+ print("Result: " + result)
+
+ # try out a replay message
+ requestor = make_requestor(SERVER_HOST, SERVER_PORT, MAIL_PROTOCOL)
+ result = requestor.request('replay', dict())
+ print("Replay Result: " + result)
diff --git a/lang/py/avro/test/sample_http_server.py b/lang/py/avro/test/sample_http_server.py
index 4fa05c3..387f907 100644
--- a/lang/py/avro/test/sample_http_server.py
+++ b/lang/py/avro/test/sample_http_server.py
@@ -23,9 +23,9 @@ import avro.ipc
import avro.protocol
try:
- import BaseHTTPServer as http_server # type: ignore
+ import BaseHTTPServer as http_server # type: ignore
except ImportError:
- import http.server as http_server # type: ignore
+ import http.server as http_server # type: ignore
MAIL_PROTOCOL_JSON = """\
{"namespace": "example.proto",
@@ -56,32 +56,35 @@ MAIL_PROTOCOL_JSON = """\
MAIL_PROTOCOL = avro.protocol.parse(MAIL_PROTOCOL_JSON)
SERVER_ADDRESS = ('localhost', 9090)
+
class MailResponder(avro.ipc.Responder):
- def __init__(self):
- avro.ipc.Responder.__init__(self, MAIL_PROTOCOL)
+ def __init__(self):
+ avro.ipc.Responder.__init__(self, MAIL_PROTOCOL)
+
+ def invoke(self, message, request):
+ if message.name == 'send':
+ request_content = request['message']
+ response = "Sent message to %(to)s from %(from)s with body %(body)s" % \
+ request_content
+ return response
+ elif message.name == 'replay':
+ return 'replay'
- def invoke(self, message, request):
- if message.name == 'send':
- request_content = request['message']
- response = "Sent message to %(to)s from %(from)s with body %(body)s" % \
- request_content
- return response
- elif message.name == 'replay':
- return 'replay'
class MailHandler(http_server.BaseHTTPRequestHandler):
- def do_POST(self):
- self.responder = MailResponder()
- call_request_reader = avro.ipc.FramedReader(self.rfile)
- call_request = call_request_reader.read_framed_message()
- resp_body = self.responder.respond(call_request)
- self.send_response(200)
- self.send_header('Content-Type', 'avro/binary')
- self.end_headers()
- resp_writer = avro.ipc.FramedWriter(self.wfile)
- resp_writer.write_framed_message(resp_body)
+ def do_POST(self):
+ self.responder = MailResponder()
+ call_request_reader = avro.ipc.FramedReader(self.rfile)
+ call_request = call_request_reader.read_framed_message()
+ resp_body = self.responder.respond(call_request)
+ self.send_response(200)
+ self.send_header('Content-Type', 'avro/binary')
+ self.end_headers()
+ resp_writer = avro.ipc.FramedWriter(self.wfile)
+ resp_writer.write_framed_message(resp_body)
+
if __name__ == '__main__':
- mail_server = http_server.HTTPServer(SERVER_ADDRESS, MailHandler)
- mail_server.allow_reuse_address = True
- mail_server.serve_forever()
+ mail_server = http_server.HTTPServer(SERVER_ADDRESS, MailHandler)
+ mail_server.allow_reuse_address = True
+ mail_server.serve_forever()
diff --git a/lang/py/avro/test/test_datafile.py b/lang/py/avro/test/test_datafile.py
index dad56ef..924a949 100644
--- a/lang/py/avro/test/test_datafile.py
+++ b/lang/py/avro/test/test_datafile.py
@@ -26,33 +26,33 @@ from avro import datafile, io, schema
from avro.codecs import Codecs
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
SCHEMAS_TO_VALIDATE = (
- ('"null"', None),
- ('"boolean"', True),
- ('"string"', unicode('adsfasdf09809dsf-=adsf')),
- ('"bytes"', b'12345abcd'),
- ('"int"', 1234),
- ('"long"', 1234),
- ('"float"', 1234.0),
- ('"double"', 1234.0),
- ('{"type": "fixed", "name": "Test", "size": 1}', b'B'),
- ('{"type": "enum", "name": "Test", "symbols": ["A", "B"]}', 'B'),
- ('{"type": "array", "items": "long"}', [1, 3, 2]),
- ('{"type": "map", "values": "long"}', {unicode('a'): 1,
- unicode('b'): 3,
- unicode('c'): 2}),
- ('["string", "null", "long"]', None),
- ("""\
+ ('"null"', None),
+ ('"boolean"', True),
+ ('"string"', unicode('adsfasdf09809dsf-=adsf')),
+ ('"bytes"', b'12345abcd'),
+ ('"int"', 1234),
+ ('"long"', 1234),
+ ('"float"', 1234.0),
+ ('"double"', 1234.0),
+ ('{"type": "fixed", "name": "Test", "size": 1}', b'B'),
+ ('{"type": "enum", "name": "Test", "symbols": ["A", "B"]}', 'B'),
+ ('{"type": "array", "items": "long"}', [1, 3, 2]),
+ ('{"type": "map", "values": "long"}', {unicode('a'): 1,
+ unicode('b'): 3,
+ unicode('c'): 2}),
+ ('["string", "null", "long"]', None),
+ ("""\
{"type": "record",
"name": "Test",
"fields": [{"name": "f", "type": "long"}]}
""", {'f': 5}),
- ("""\
+ ("""\
{"type": "record",
"name": "Lisp",
"fields": [{"name": "value",
@@ -67,152 +67,156 @@ SCHEMAS_TO_VALIDATE = (
FILENAME = 'test_datafile.out'
CODECS_TO_VALIDATE = Codecs.supported_codec_names()
+
class TestDataFile(unittest.TestCase):
- def test_round_trip(self):
- print('')
- print('TEST ROUND TRIP')
- print('===============')
- print('')
- correct = 0
- for i, (example_schema, datum) in enumerate(SCHEMAS_TO_VALIDATE):
- for codec in CODECS_TO_VALIDATE:
+ def test_round_trip(self):
print('')
- print('SCHEMA NUMBER %d' % (i + 1))
- print('================')
+ print('TEST ROUND TRIP')
+ print('===============')
print('')
- print('Schema: %s' % example_schema)
- print('Datum: %s' % datum)
- print('Codec: %s' % codec)
-
- # write data in binary to file 10 times
+ correct = 0
+ for i, (example_schema, datum) in enumerate(SCHEMAS_TO_VALIDATE):
+ for codec in CODECS_TO_VALIDATE:
+ print('')
+ print('SCHEMA NUMBER %d' % (i + 1))
+ print('================')
+ print('')
+ print('Schema: %s' % example_schema)
+ print('Datum: %s' % datum)
+ print('Codec: %s' % codec)
+
+ # write data in binary to file 10 times
+ writer = open(FILENAME, 'wb')
+ datum_writer = io.DatumWriter()
+ schema_object = schema.parse(example_schema)
+ dfw = datafile.DataFileWriter(writer, datum_writer, schema_object, codec=codec)
+ for i in range(10):
+ dfw.append(datum)
+ dfw.close()
+
+ # read data in binary from file
+ reader = open(FILENAME, 'rb')
+ datum_reader = io.DatumReader()
+ dfr = datafile.DataFileReader(reader, datum_reader)
+ round_trip_data = []
+ for datum in dfr:
+ round_trip_data.append(datum)
+
+ print('Round Trip Data: %s' % round_trip_data)
+ print('Round Trip Data Length: %d' % len(round_trip_data))
+ is_correct = [datum] * 10 == round_trip_data
+ if is_correct:
+ correct += 1
+ print('Correct Round Trip: %s' % is_correct)
+ print('')
+ os.remove(FILENAME)
+ self.assertEquals(correct, len(CODECS_TO_VALIDATE) * len(SCHEMAS_TO_VALIDATE))
+
+ def test_append(self):
+ print('')
+ print('TEST APPEND')
+ print('===========')
+ print('')
+ correct = 0
+ for i, (example_schema, datum) in enumerate(SCHEMAS_TO_VALIDATE):
+ for codec in CODECS_TO_VALIDATE:
+ print('')
+ print('SCHEMA NUMBER %d' % (i + 1))
+ print('================')
+ print('')
+ print('Schema: %s' % example_schema)
+ print('Datum: %s' % datum)
+ print('Codec: %s' % codec)
+
+ # write data in binary to file once
+ writer = open(FILENAME, 'wb')
+ datum_writer = io.DatumWriter()
+ schema_object = schema.parse(example_schema)
+ dfw = datafile.DataFileWriter(writer, datum_writer, schema_object, codec=codec)
+ dfw.append(datum)
+ dfw.close()
+
+ # open file, write, and close nine times
+ for i in range(9):
+ writer = open(FILENAME, 'ab+')
+ dfw = datafile.DataFileWriter(writer, io.DatumWriter())
+ dfw.append(datum)
+ dfw.close()
+
+ # read data in binary from file
+ reader = open(FILENAME, 'rb')
+ datum_reader = io.DatumReader()
+ dfr = datafile.DataFileReader(reader, datum_reader)
+ appended_data = []
+ for datum in dfr:
+ appended_data.append(datum)
+
+ print('Appended Data: %s' % appended_data)
+ print('Appended Data Length: %d' % len(appended_data))
+ is_correct = [datum] * 10 == appended_data
+ if is_correct:
+ correct += 1
+ print('Correct Appended: %s' % is_correct)
+ print('')
+ os.remove(FILENAME)
+ self.assertEquals(correct, len(CODECS_TO_VALIDATE) * len(SCHEMAS_TO_VALIDATE))
+
+ def test_context_manager(self):
+ """Test the writer with a 'with' statement."""
writer = open(FILENAME, 'wb')
datum_writer = io.DatumWriter()
- schema_object = schema.parse(example_schema)
- dfw = datafile.DataFileWriter(writer, datum_writer, schema_object, codec=codec)
- for i in range(10):
- dfw.append(datum)
- dfw.close()
-
- # read data in binary from file
+ sample_schema, sample_datum = SCHEMAS_TO_VALIDATE[1]
+ schema_object = schema.parse(sample_schema)
+ with datafile.DataFileWriter(writer, datum_writer, schema_object) as dfw:
+ dfw.append(sample_datum)
+ self.assertTrue(writer.closed)
+
+ # Test the reader with a 'with' statement.
+ datums = []
reader = open(FILENAME, 'rb')
datum_reader = io.DatumReader()
- dfr = datafile.DataFileReader(reader, datum_reader)
- round_trip_data = []
- for datum in dfr:
- round_trip_data.append(datum)
-
- print('Round Trip Data: %s' % round_trip_data)
- print('Round Trip Data Length: %d' % len(round_trip_data))
- is_correct = [datum] * 10 == round_trip_data
- if is_correct: correct += 1
- print('Correct Round Trip: %s' % is_correct)
- print('')
- os.remove(FILENAME)
- self.assertEquals(correct, len(CODECS_TO_VALIDATE)*len(SCHEMAS_TO_VALIDATE))
-
- def test_append(self):
- print('')
- print('TEST APPEND')
- print('===========')
- print('')
- correct = 0
- for i, (example_schema, datum) in enumerate(SCHEMAS_TO_VALIDATE):
- for codec in CODECS_TO_VALIDATE:
- print('')
- print('SCHEMA NUMBER %d' % (i + 1))
- print('================')
- print('')
- print('Schema: %s' % example_schema)
- print('Datum: %s' % datum)
- print('Codec: %s' % codec)
+ with datafile.DataFileReader(reader, datum_reader) as dfr:
+ for datum in dfr:
+ datums.append(datum)
+ self.assertTrue(reader.closed)
- # write data in binary to file once
+ def test_metadata(self):
+ # Test the writer with a 'with' statement.
writer = open(FILENAME, 'wb')
datum_writer = io.DatumWriter()
- schema_object = schema.parse(example_schema)
- dfw = datafile.DataFileWriter(writer, datum_writer, schema_object, codec=codec)
- dfw.append(datum)
- dfw.close()
-
- # open file, write, and close nine times
- for i in range(9):
- writer = open(FILENAME, 'ab+')
- dfw = datafile.DataFileWriter(writer, io.DatumWriter())
- dfw.append(datum)
- dfw.close()
-
- # read data in binary from file
+ sample_schema, sample_datum = SCHEMAS_TO_VALIDATE[1]
+ schema_object = schema.parse(sample_schema)
+ with datafile.DataFileWriter(writer, datum_writer, schema_object) as dfw:
+ dfw.set_meta('test.string', b'foo')
+ dfw.set_meta('test.number', b'1')
+ dfw.append(sample_datum)
+ self.assertTrue(writer.closed)
+
+ # Test the reader with a 'with' statement.
+ datums = []
reader = open(FILENAME, 'rb')
datum_reader = io.DatumReader()
- dfr = datafile.DataFileReader(reader, datum_reader)
- appended_data = []
- for datum in dfr:
- appended_data.append(datum)
-
- print('Appended Data: %s' % appended_data)
- print('Appended Data Length: %d' % len(appended_data))
- is_correct = [datum] * 10 == appended_data
- if is_correct: correct += 1
- print('Correct Appended: %s' % is_correct)
- print('')
- os.remove(FILENAME)
- self.assertEquals(correct, len(CODECS_TO_VALIDATE)*len(SCHEMAS_TO_VALIDATE))
-
- def test_context_manager(self):
- """Test the writer with a 'with' statement."""
- writer = open(FILENAME, 'wb')
- datum_writer = io.DatumWriter()
- sample_schema, sample_datum = SCHEMAS_TO_VALIDATE[1]
- schema_object = schema.parse(sample_schema)
- with datafile.DataFileWriter(writer, datum_writer, schema_object) as dfw:
- dfw.append(sample_datum)
- self.assertTrue(writer.closed)
-
- # Test the reader with a 'with' statement.
- datums = []
- reader = open(FILENAME, 'rb')
- datum_reader = io.DatumReader()
- with datafile.DataFileReader(reader, datum_reader) as dfr:
- for datum in dfr:
- datums.append(datum)
- self.assertTrue(reader.closed)
-
- def test_metadata(self):
- # Test the writer with a 'with' statement.
- writer = open(FILENAME, 'wb')
- datum_writer = io.DatumWriter()
- sample_schema, sample_datum = SCHEMAS_TO_VALIDATE[1]
- schema_object = schema.parse(sample_schema)
- with datafile.DataFileWriter(writer, datum_writer, schema_object) as dfw:
- dfw.set_meta('test.string', b'foo')
- dfw.set_meta('test.number', b'1')
- dfw.append(sample_datum)
- self.assertTrue(writer.closed)
-
- # Test the reader with a 'with' statement.
- datums = []
- reader = open(FILENAME, 'rb')
- datum_reader = io.DatumReader()
- with datafile.DataFileReader(reader, datum_reader) as dfr:
- self.assertEquals(b'foo', dfr.get_meta('test.string'))
- self.assertEquals(b'1', dfr.get_meta('test.number'))
- for datum in dfr:
- datums.append(datum)
- self.assertTrue(reader.closed)
-
- def test_empty_datafile(self):
- """A reader should not fail to read a file consisting of a single empty block."""
- sample_schema = schema.parse(SCHEMAS_TO_VALIDATE[1][0])
- with datafile.DataFileWriter(open(FILENAME, 'wb'), io.DatumWriter(),
- sample_schema) as dfw:
- dfw.flush()
- # Write an empty block
- dfw.encoder.write_long(0)
- dfw.encoder.write_long(0)
- dfw.writer.write(dfw.sync_marker)
-
- with datafile.DataFileReader(open(FILENAME, 'rb'), io.DatumReader()) as dfr:
- self.assertEqual([], list(dfr))
+ with datafile.DataFileReader(reader, datum_reader) as dfr:
+ self.assertEquals(b'foo', dfr.get_meta('test.string'))
+ self.assertEquals(b'1', dfr.get_meta('test.number'))
+ for datum in dfr:
+ datums.append(datum)
+ self.assertTrue(reader.closed)
+
+ def test_empty_datafile(self):
+ """A reader should not fail to read a file consisting of a single empty block."""
+ sample_schema = schema.parse(SCHEMAS_TO_VALIDATE[1][0])
+ with datafile.DataFileWriter(open(FILENAME, 'wb'), io.DatumWriter(),
+ sample_schema) as dfw:
+ dfw.flush()
+ # Write an empty block
+ dfw.encoder.write_long(0)
+ dfw.encoder.write_long(0)
+ dfw.writer.write(dfw.sync_marker)
+
+ with datafile.DataFileReader(open(FILENAME, 'rb'), io.DatumReader()) as dfr:
+ self.assertEqual([], list(dfr))
+
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
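(A minimal sketch of the datafile round trip that test_round_trip and test_context_manager above exercise; the file name and datum below are illustrative only and are not part of this commit.)

    import avro.datafile
    import avro.io
    import avro.schema

    schema_object = avro.schema.parse(
        '{"type": "record", "name": "Test",'
        ' "fields": [{"name": "f", "type": "long"}]}')

    # Write a single datum, then read it back, using the same
    # context-manager pattern as the tests above.
    with avro.datafile.DataFileWriter(open('example.avro', 'wb'),
                                      avro.io.DatumWriter(), schema_object) as dfw:
        dfw.append({'f': 5})

    with avro.datafile.DataFileReader(open('example.avro', 'rb'),
                                      avro.io.DatumReader()) as dfr:
        print(list(dfr))  # [{'f': 5}]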
diff --git a/lang/py/avro/test/test_datafile_interop.py b/lang/py/avro/test/test_datafile_interop.py
index 2f2ac2b..1b45945 100644
--- a/lang/py/avro/test/test_datafile_interop.py
+++ b/lang/py/avro/test/test_datafile_interop.py
@@ -30,27 +30,28 @@ _INTEROP_DATA_DIR = os.path.join(os.path.dirname(avro.__file__), 'test', 'intero
@unittest.skipUnless(os.path.exists(_INTEROP_DATA_DIR),
"{} does not exist".format(_INTEROP_DATA_DIR))
class TestDataFileInterop(unittest.TestCase):
- def test_interop(self):
- """Test Interop"""
- for f in os.listdir(_INTEROP_DATA_DIR):
- filename = os.path.join(_INTEROP_DATA_DIR, f)
- assert os.stat(filename).st_size > 0
- base_ext = os.path.splitext(os.path.basename(f))[0].split('_', 1)
- if len(base_ext) < 2 or base_ext[1] in datafile.VALID_CODECS:
- print('READING %s' % f)
- print()
-
- # read data in binary from file
- datum_reader = io.DatumReader()
- with open(filename, 'rb') as reader:
- dfr = datafile.DataFileReader(reader, datum_reader)
- i = 0
- for i, datum in enumerate(dfr, 1):
- assert datum is not None
- assert i > 0
- else:
- print('SKIPPING %s due to an unsupported codec' % f)
- print()
+ def test_interop(self):
+ """Test Interop"""
+ for f in os.listdir(_INTEROP_DATA_DIR):
+ filename = os.path.join(_INTEROP_DATA_DIR, f)
+ assert os.stat(filename).st_size > 0
+ base_ext = os.path.splitext(os.path.basename(f))[0].split('_', 1)
+ if len(base_ext) < 2 or base_ext[1] in datafile.VALID_CODECS:
+ print('READING %s' % f)
+ print()
+
+ # read data in binary from file
+ datum_reader = io.DatumReader()
+ with open(filename, 'rb') as reader:
+ dfr = datafile.DataFileReader(reader, datum_reader)
+ i = 0
+ for i, datum in enumerate(dfr, 1):
+ assert datum is not None
+ assert i > 0
+ else:
+ print('SKIPPING %s due to an unsupported codec' % f)
+ print()
+
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/lang/py/avro/test/test_init.py b/lang/py/avro/test/test_init.py
index cd423a9..f41297f 100644
--- a/lang/py/avro/test/test_init.py
+++ b/lang/py/avro/test/test_init.py
@@ -24,10 +24,10 @@ import avro
class TestVersion(unittest.TestCase):
-  def test_import_version(self):
-    # make sure we have __version__ attribute in avro module
-    avro.__version__
+    def test_import_version(self):
+        # make sure we have __version__ attribute in avro module
+        self.assertTrue(hasattr(avro, '__version__'))
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/lang/py/avro/test/test_io.py b/lang/py/avro/test/test_io.py
index 7f21ae3..0a93db2 100644
--- a/lang/py/avro/test/test_io.py
+++ b/lang/py/avro/test/test_io.py
@@ -29,70 +29,70 @@ import avro.io
from avro import schema, timezones
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
SCHEMAS_TO_VALIDATE = (
- ('"null"', None),
- ('"boolean"', True),
- ('"string"', unicode('adsfasdf09809dsf-=adsf')),
- ('"bytes"', b'12345abcd'),
- ('"int"', 1234),
- ('"long"', 1234),
- ('"float"', 1234.0),
- ('"double"', 1234.0),
- ('{"type": "fixed", "name": "Test", "size": 1}', b'B'),
- ('{"type": "fixed", "logicalType": "decimal", "name": "Test", "size": 8, "precision": 5, "scale": 4}',
- Decimal('3.1415')),
- ('{"type": "fixed", "logicalType": "decimal", "name": "Test", "size": 8, "precision": 5, "scale": 4}',
- Decimal('-3.1415')),
- ('{"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4}', Decimal('3.1415')),
- ('{"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4}', Decimal('-3.1415')),
- ('{"type": "enum", "name": "Test", "symbols": ["A", "B"]}', 'B'),
- ('{"type": "array", "items": "long"}', [1, 3, 2]),
- ('{"type": "map", "values": "long"}', {unicode('a'): 1,
- unicode('b'): 3,
- unicode('c'): 2}),
- ('["string", "null", "long"]', None),
- ('{"type": "int", "logicalType": "date"}', datetime.date(2000, 1, 1)),
- ('{"type": "int", "logicalType": "time-millis"}', datetime.time(23, 59, 59, 999000)),
- ('{"type": "int", "logicalType": "time-millis"}', datetime.time(0, 0, 0, 000000)),
- ('{"type": "long", "logicalType": "time-micros"}', datetime.time(23, 59, 59, 999999)),
- ('{"type": "long", "logicalType": "time-micros"}', datetime.time(0, 0, 0, 000000)),
- (
- '{"type": "long", "logicalType": "timestamp-millis"}',
- datetime.datetime(1000, 1, 1, 0, 0, 0, 000000, tzinfo=timezones.utc)
- ),
- (
- '{"type": "long", "logicalType": "timestamp-millis"}',
- datetime.datetime(9999, 12, 31, 23, 59, 59, 999000, tzinfo=timezones.utc)
- ),
- (
- '{"type": "long", "logicalType": "timestamp-millis"}',
- datetime.datetime(2000, 1, 18, 2, 2, 1, 100000, tzinfo=timezones.tst)
- ),
- (
- '{"type": "long", "logicalType": "timestamp-micros"}',
- datetime.datetime(1000, 1, 1, 0, 0, 0, 000000, tzinfo=timezones.utc)
- ),
- (
- '{"type": "long", "logicalType": "timestamp-micros"}',
- datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=timezones.utc)
- ),
- (
- '{"type": "long", "logicalType": "timestamp-micros"}',
- datetime.datetime(2000, 1, 18, 2, 2, 1, 123499, tzinfo=timezones.tst)
- ),
- ('{"type": "string", "logicalType": "uuid"}', u'12345abcd'),
- ('{"type": "string", "logicalType": "unknown-logical-type"}', u'12345abcd'),
- ('{"type": "string", "logicalType": "timestamp-millis"}', u'12345abcd'),
- ("""\
+ ('"null"', None),
+ ('"boolean"', True),
+ ('"string"', unicode('adsfasdf09809dsf-=adsf')),
+ ('"bytes"', b'12345abcd'),
+ ('"int"', 1234),
+ ('"long"', 1234),
+ ('"float"', 1234.0),
+ ('"double"', 1234.0),
+ ('{"type": "fixed", "name": "Test", "size": 1}', b'B'),
+ ('{"type": "fixed", "logicalType": "decimal", "name": "Test", "size": 8, "precision": 5, "scale": 4}',
+ Decimal('3.1415')),
+ ('{"type": "fixed", "logicalType": "decimal", "name": "Test", "size": 8, "precision": 5, "scale": 4}',
+ Decimal('-3.1415')),
+ ('{"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4}', Decimal('3.1415')),
+ ('{"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4}', Decimal('-3.1415')),
+ ('{"type": "enum", "name": "Test", "symbols": ["A", "B"]}', 'B'),
+ ('{"type": "array", "items": "long"}', [1, 3, 2]),
+ ('{"type": "map", "values": "long"}', {unicode('a'): 1,
+ unicode('b'): 3,
+ unicode('c'): 2}),
+ ('["string", "null", "long"]', None),
+ ('{"type": "int", "logicalType": "date"}', datetime.date(2000, 1, 1)),
+ ('{"type": "int", "logicalType": "time-millis"}', datetime.time(23, 59, 59, 999000)),
+ ('{"type": "int", "logicalType": "time-millis"}', datetime.time(0, 0, 0, 000000)),
+ ('{"type": "long", "logicalType": "time-micros"}', datetime.time(23, 59, 59, 999999)),
+ ('{"type": "long", "logicalType": "time-micros"}', datetime.time(0, 0, 0, 000000)),
+ (
+ '{"type": "long", "logicalType": "timestamp-millis"}',
+ datetime.datetime(1000, 1, 1, 0, 0, 0, 000000, tzinfo=timezones.utc)
+ ),
+ (
+ '{"type": "long", "logicalType": "timestamp-millis"}',
+ datetime.datetime(9999, 12, 31, 23, 59, 59, 999000, tzinfo=timezones.utc)
+ ),
+ (
+ '{"type": "long", "logicalType": "timestamp-millis"}',
+ datetime.datetime(2000, 1, 18, 2, 2, 1, 100000, tzinfo=timezones.tst)
+ ),
+ (
+ '{"type": "long", "logicalType": "timestamp-micros"}',
+ datetime.datetime(1000, 1, 1, 0, 0, 0, 000000, tzinfo=timezones.utc)
+ ),
+ (
+ '{"type": "long", "logicalType": "timestamp-micros"}',
+ datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=timezones.utc)
+ ),
+ (
+ '{"type": "long", "logicalType": "timestamp-micros"}',
+ datetime.datetime(2000, 1, 18, 2, 2, 1, 123499, tzinfo=timezones.tst)
+ ),
+ ('{"type": "string", "logicalType": "uuid"}', u'12345abcd'),
+ ('{"type": "string", "logicalType": "unknown-logical-type"}', u'12345abcd'),
+ ('{"type": "string", "logicalType": "timestamp-millis"}', u'12345abcd'),
+ ("""\
{"type": "record",
"name": "Test",
"fields": [{"name": "f", "type": "long"}]}
""", {'f': 5}),
- ("""\
+ ("""\
{"type": "record",
"name": "Lisp",
"fields": [{"name": "value",
@@ -105,34 +105,34 @@ SCHEMAS_TO_VALIDATE = (
)
BINARY_ENCODINGS = (
- (0, b'00'),
- (-1, b'01'),
- (1, b'02'),
- (-2, b'03'),
- (2, b'04'),
- (-64, b'7f'),
- (64, b'80 01'),
- (8192, b'80 80 01'),
- (-8193, b'81 80 01'),
+ (0, b'00'),
+ (-1, b'01'),
+ (1, b'02'),
+ (-2, b'03'),
+ (2, b'04'),
+ (-64, b'7f'),
+ (64, b'80 01'),
+ (8192, b'80 80 01'),
+ (-8193, b'81 80 01'),
)
DEFAULT_VALUE_EXAMPLES = (
- ('"null"', 'null', None),
- ('"boolean"', 'true', True),
- ('"string"', '"foo"', u'foo'),
- ('"bytes"', '"\u00FF\u00FF"', u'\xff\xff'),
- ('"int"', '5', 5),
- ('"long"', '5', 5),
- ('"float"', '1.1', 1.1),
- ('"double"', '1.1', 1.1),
- ('{"type": "fixed", "name": "F", "size": 2}', '"\u00FF\u00FF"', u'\xff\xff'),
- ('{"type": "enum", "name": "F", "symbols": ["FOO", "BAR"]}', '"FOO"', 'FOO'),
- ('{"type": "array", "items": "int"}', '[1, 2, 3]', [1, 2, 3]),
- ('{"type": "map", "values": "int"}', '{"a": 1, "b": 2}', {unicode('a'): 1,
- unicode('b'): 2}),
- ('["int", "null"]', '5', 5),
- ('{"type": "record", "name": "F", "fields": [{"name": "A", "type": "int"}]}',
- '{"A": 5}', {'A': 5}),
+ ('"null"', 'null', None),
+ ('"boolean"', 'true', True),
+ ('"string"', '"foo"', u'foo'),
+ ('"bytes"', '"\u00FF\u00FF"', u'\xff\xff'),
+ ('"int"', '5', 5),
+ ('"long"', '5', 5),
+ ('"float"', '1.1', 1.1),
+ ('"double"', '1.1', 1.1),
+ ('{"type": "fixed", "name": "F", "size": 2}', '"\u00FF\u00FF"', u'\xff\xff'),
+ ('{"type": "enum", "name": "F", "symbols": ["FOO", "BAR"]}', '"FOO"', 'FOO'),
+ ('{"type": "array", "items": "int"}', '[1, 2, 3]', [1, 2, 3]),
+ ('{"type": "map", "values": "int"}', '{"a": 1, "b": 2}', {unicode('a'): 1,
+ unicode('b'): 2}),
+ ('["int", "null"]', '5', 5),
+ ('{"type": "record", "name": "F", "fields": [{"name": "A", "type": "int"}]}',
+ '{"A": 5}', {'A': 5}),
)
LONG_RECORD_SCHEMA = schema.parse("""\
@@ -148,248 +148,261 @@ LONG_RECORD_SCHEMA = schema.parse("""\
LONG_RECORD_DATUM = {'A': 1, 'B': 2, 'C': 3, 'D': 4, 'E': 5, 'F': 6, 'G': 7}
+
def avro_hexlify(reader):
- """Return the hex value, as a string, of a binary-encoded int or long."""
- b = []
- current_byte = reader.read(1)
- b.append(hexlify(current_byte))
- while (ord(current_byte) & 0x80) != 0:
+ """Return the hex value, as a string, of a binary-encoded int or long."""
+ b = []
current_byte = reader.read(1)
b.append(hexlify(current_byte))
- return b' '.join(b)
+ while (ord(current_byte) & 0x80) != 0:
+ current_byte = reader.read(1)
+ b.append(hexlify(current_byte))
+ return b' '.join(b)
+
def print_test_name(test_name):
- print('')
- print(test_name)
- print('=' * len(test_name))
- print('')
+ print('')
+ print(test_name)
+ print('=' * len(test_name))
+ print('')
+
def write_datum(datum, writers_schema):
- writer = io.BytesIO()
- encoder = avro.io.BinaryEncoder(writer)
- datum_writer = avro.io.DatumWriter(writers_schema)
- datum_writer.write(datum, encoder)
- return writer, encoder, datum_writer
+ writer = io.BytesIO()
+ encoder = avro.io.BinaryEncoder(writer)
+ datum_writer = avro.io.DatumWriter(writers_schema)
+ datum_writer.write(datum, encoder)
+ return writer, encoder, datum_writer
+
def read_datum(buffer, writers_schema, readers_schema=None):
- reader = io.BytesIO(buffer.getvalue())
- decoder = avro.io.BinaryDecoder(reader)
- datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
- return datum_reader.read(decoder)
+ reader = io.BytesIO(buffer.getvalue())
+ decoder = avro.io.BinaryDecoder(reader)
+ datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
+ return datum_reader.read(decoder)
+
def check_binary_encoding(number_type):
- print_test_name('TEST BINARY %s ENCODING' % number_type.upper())
- correct = 0
- for datum, hex_encoding in BINARY_ENCODINGS:
- print('Datum: %d' % datum)
- print('Correct Encoding: %s' % hex_encoding)
-
- writers_schema = schema.parse('"%s"' % number_type.lower())
- writer, encoder, datum_writer = write_datum(datum, writers_schema)
- writer.seek(0)
- hex_val = avro_hexlify(writer)
-
- print('Read Encoding: %s' % hex_val)
- if hex_encoding == hex_val: correct += 1
- print('')
- return correct
+ print_test_name('TEST BINARY %s ENCODING' % number_type.upper())
+ correct = 0
+ for datum, hex_encoding in BINARY_ENCODINGS:
+ print('Datum: %d' % datum)
+ print('Correct Encoding: %s' % hex_encoding)
-def check_skip_number(number_type):
- print_test_name('TEST SKIP %s' % number_type.upper())
- correct = 0
- for value_to_skip, hex_encoding in BINARY_ENCODINGS:
- VALUE_TO_READ = 6253
- print('Value to Skip: %d' % value_to_skip)
-
- # write the value to skip and a known value
- writers_schema = schema.parse('"%s"' % number_type.lower())
- writer, encoder, datum_writer = write_datum(value_to_skip, writers_schema)
- datum_writer.write(VALUE_TO_READ, encoder)
-
- # skip the value
- reader = io.BytesIO(writer.getvalue())
- decoder = avro.io.BinaryDecoder(reader)
- decoder.skip_long()
+ writers_schema = schema.parse('"%s"' % number_type.lower())
+ writer, encoder, datum_writer = write_datum(datum, writers_schema)
+ writer.seek(0)
+ hex_val = avro_hexlify(writer)
- # read data from string buffer
- datum_reader = avro.io.DatumReader(writers_schema)
- read_value = datum_reader.read(decoder)
+ print('Read Encoding: %s' % hex_val)
+ if hex_encoding == hex_val:
+ correct += 1
+ print('')
+ return correct
- print('Read Value: %d' % read_value)
- if read_value == VALUE_TO_READ: correct += 1
- print('')
- return correct
-class TestIO(unittest.TestCase):
- #
- # BASIC FUNCTIONALITY
- #
-
- def test_validate(self):
- print_test_name('TEST VALIDATE')
- passed = 0
- for example_schema, datum in SCHEMAS_TO_VALIDATE:
- print('Schema: %s' % example_schema)
- print('Datum: %s' % datum)
- validated = avro.io.validate(schema.parse(example_schema), datum)
- print('Valid: %s' % validated)
- if validated: passed += 1
- self.assertEquals(passed, len(SCHEMAS_TO_VALIDATE))
-
- def test_round_trip(self):
- print_test_name('TEST ROUND TRIP')
+def check_skip_number(number_type):
+ print_test_name('TEST SKIP %s' % number_type.upper())
correct = 0
- for example_schema, datum in SCHEMAS_TO_VALIDATE:
- print('Schema: %s' % example_schema)
- print('Datum: %s' % datum)
-
- writers_schema = schema.parse(example_schema)
- writer, encoder, datum_writer = write_datum(datum, writers_schema)
- round_trip_datum = read_datum(writer, writers_schema)
-
- print('Round Trip Datum: %s' % round_trip_datum)
- if isinstance(round_trip_datum, Decimal):
- round_trip_datum = round_trip_datum.to_eng_string()
- datum = str(datum)
- elif isinstance(round_trip_datum, datetime.datetime):
- datum = datum.astimezone(tz=timezones.utc)
- if datum == round_trip_datum:
- correct += 1
- self.assertEquals(correct, len(SCHEMAS_TO_VALIDATE))
-
- #
- # BINARY ENCODING OF INT AND LONG
- #
-
- def test_binary_int_encoding(self):
- correct = check_binary_encoding('int')
- self.assertEquals(correct, len(BINARY_ENCODINGS))
-
- def test_binary_long_encoding(self):
- correct = check_binary_encoding('long')
- self.assertEquals(correct, len(BINARY_ENCODINGS))
-
- def test_skip_int(self):
- correct = check_skip_number('int')
- self.assertEquals(correct, len(BINARY_ENCODINGS))
-
- def test_skip_long(self):
- correct = check_skip_number('long')
- self.assertEquals(correct, len(BINARY_ENCODINGS))
-
- #
- # SCHEMA RESOLUTION
- #
-
- def test_schema_promotion(self):
- print_test_name('TEST SCHEMA PROMOTION')
- # note that checking writers_schema.type in read_data
- # allows us to handle promotion correctly
- promotable_schemas = ['"int"', '"long"', '"float"', '"double"']
- incorrect = 0
- for i, ws in enumerate(promotable_schemas):
- writers_schema = schema.parse(ws)
- datum_to_write = 219
- for rs in promotable_schemas[i + 1:]:
- readers_schema = schema.parse(rs)
- writer, enc, dw = write_datum(datum_to_write, writers_schema)
- datum_read = read_datum(writer, writers_schema, readers_schema)
- print('Writer: %s Reader: %s' % (writers_schema, readers_schema))
- print('Datum Read: %s' % datum_read)
- if datum_read != datum_to_write: incorrect += 1
- self.assertEquals(incorrect, 0)
+ for value_to_skip, hex_encoding in BINARY_ENCODINGS:
+ VALUE_TO_READ = 6253
+ print('Value to Skip: %d' % value_to_skip)
+
+ # write the value to skip and a known value
+ writers_schema = schema.parse('"%s"' % number_type.lower())
+ writer, encoder, datum_writer = write_datum(value_to_skip, writers_schema)
+ datum_writer.write(VALUE_TO_READ, encoder)
+
+ # skip the value
+ reader = io.BytesIO(writer.getvalue())
+ decoder = avro.io.BinaryDecoder(reader)
+ decoder.skip_long()
+
+ # read data from string buffer
+ datum_reader = avro.io.DatumReader(writers_schema)
+ read_value = datum_reader.read(decoder)
+
+ print('Read Value: %d' % read_value)
+ if read_value == VALUE_TO_READ:
+ correct += 1
+ print('')
+ return correct
- def test_unknown_symbol(self):
- print_test_name('TEST UNKNOWN SYMBOL')
- writers_schema = schema.parse("""\
+
+class TestIO(unittest.TestCase):
+ #
+ # BASIC FUNCTIONALITY
+ #
+
+ def test_validate(self):
+ print_test_name('TEST VALIDATE')
+ passed = 0
+ for example_schema, datum in SCHEMAS_TO_VALIDATE:
+ print('Schema: %s' % example_schema)
+ print('Datum: %s' % datum)
+ validated = avro.io.validate(schema.parse(example_schema), datum)
+ print('Valid: %s' % validated)
+ if validated:
+ passed += 1
+ self.assertEquals(passed, len(SCHEMAS_TO_VALIDATE))
+
+ def test_round_trip(self):
+ print_test_name('TEST ROUND TRIP')
+ correct = 0
+ for example_schema, datum in SCHEMAS_TO_VALIDATE:
+ print('Schema: %s' % example_schema)
+ print('Datum: %s' % datum)
+
+ writers_schema = schema.parse(example_schema)
+ writer, encoder, datum_writer = write_datum(datum, writers_schema)
+ round_trip_datum = read_datum(writer, writers_schema)
+
+ print('Round Trip Datum: %s' % round_trip_datum)
+ if isinstance(round_trip_datum, Decimal):
+ round_trip_datum = round_trip_datum.to_eng_string()
+ datum = str(datum)
+ elif isinstance(round_trip_datum, datetime.datetime):
+ datum = datum.astimezone(tz=timezones.utc)
+ if datum == round_trip_datum:
+ correct += 1
+ self.assertEquals(correct, len(SCHEMAS_TO_VALIDATE))
+
+ #
+ # BINARY ENCODING OF INT AND LONG
+ #
+
+ def test_binary_int_encoding(self):
+ correct = check_binary_encoding('int')
+ self.assertEquals(correct, len(BINARY_ENCODINGS))
+
+ def test_binary_long_encoding(self):
+ correct = check_binary_encoding('long')
+ self.assertEquals(correct, len(BINARY_ENCODINGS))
+
+ def test_skip_int(self):
+ correct = check_skip_number('int')
+ self.assertEquals(correct, len(BINARY_ENCODINGS))
+
+ def test_skip_long(self):
+ correct = check_skip_number('long')
+ self.assertEquals(correct, len(BINARY_ENCODINGS))
+
+ #
+ # SCHEMA RESOLUTION
+ #
+
+ def test_schema_promotion(self):
+ print_test_name('TEST SCHEMA PROMOTION')
+ # note that checking writers_schema.type in read_data
+ # allows us to handle promotion correctly
+ promotable_schemas = ['"int"', '"long"', '"float"', '"double"']
+ incorrect = 0
+ for i, ws in enumerate(promotable_schemas):
+ writers_schema = schema.parse(ws)
+ datum_to_write = 219
+ for rs in promotable_schemas[i + 1:]:
+ readers_schema = schema.parse(rs)
+ writer, enc, dw = write_datum(datum_to_write, writers_schema)
+ datum_read = read_datum(writer, writers_schema, readers_schema)
+ print('Writer: %s Reader: %s' % (writers_schema, readers_schema))
+ print('Datum Read: %s' % datum_read)
+ if datum_read != datum_to_write:
+ incorrect += 1
+ self.assertEquals(incorrect, 0)
+
+ def test_unknown_symbol(self):
+ print_test_name('TEST UNKNOWN SYMBOL')
+ writers_schema = schema.parse("""\
{"type": "enum", "name": "Test",
"symbols": ["FOO", "BAR"]}""")
- datum_to_write = 'FOO'
+ datum_to_write = 'FOO'
- readers_schema = schema.parse("""\
+ readers_schema = schema.parse("""\
{"type": "enum", "name": "Test",
"symbols": ["BAR", "BAZ"]}""")
- writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
- reader = io.BytesIO(writer.getvalue())
- decoder = avro.io.BinaryDecoder(reader)
- datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
- self.assertRaises(avro.io.SchemaResolutionException, datum_reader.read, decoder)
+ writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
+ reader = io.BytesIO(writer.getvalue())
+ decoder = avro.io.BinaryDecoder(reader)
+ datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
+ self.assertRaises(avro.io.SchemaResolutionException, datum_reader.read, decoder)
- def test_default_value(self):
- print_test_name('TEST DEFAULT VALUE')
- writers_schema = LONG_RECORD_SCHEMA
- datum_to_write = LONG_RECORD_DATUM
+ def test_default_value(self):
+ print_test_name('TEST DEFAULT VALUE')
+ writers_schema = LONG_RECORD_SCHEMA
+ datum_to_write = LONG_RECORD_DATUM
- correct = 0
- for field_type, default_json, default_datum in DEFAULT_VALUE_EXAMPLES:
- readers_schema = schema.parse("""\
+ correct = 0
+ for field_type, default_json, default_datum in DEFAULT_VALUE_EXAMPLES:
+ readers_schema = schema.parse("""\
{"type": "record", "name": "Test",
"fields": [{"name": "H", "type": %s, "default": %s}]}
""" % (field_type, default_json))
- datum_to_read = {'H': default_datum}
+ datum_to_read = {'H': default_datum}
- writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
- datum_read = read_datum(writer, writers_schema, readers_schema)
- print('Datum Read: %s' % datum_read)
- if datum_to_read == datum_read: correct += 1
- self.assertEquals(correct, len(DEFAULT_VALUE_EXAMPLES))
+ writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
+ datum_read = read_datum(writer, writers_schema, readers_schema)
+ print('Datum Read: %s' % datum_read)
+ if datum_to_read == datum_read:
+ correct += 1
+ self.assertEquals(correct, len(DEFAULT_VALUE_EXAMPLES))
- def test_no_default_value(self):
- print_test_name('TEST NO DEFAULT VALUE')
- writers_schema = LONG_RECORD_SCHEMA
- datum_to_write = LONG_RECORD_DATUM
+ def test_no_default_value(self):
+ print_test_name('TEST NO DEFAULT VALUE')
+ writers_schema = LONG_RECORD_SCHEMA
+ datum_to_write = LONG_RECORD_DATUM
- readers_schema = schema.parse("""\
+ readers_schema = schema.parse("""\
{"type": "record", "name": "Test",
"fields": [{"name": "H", "type": "int"}]}""")
- writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
- reader = io.BytesIO(writer.getvalue())
- decoder = avro.io.BinaryDecoder(reader)
- datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
- self.assertRaises(avro.io.SchemaResolutionException, datum_reader.read, decoder)
+ writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
+ reader = io.BytesIO(writer.getvalue())
+ decoder = avro.io.BinaryDecoder(reader)
+ datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
+ self.assertRaises(avro.io.SchemaResolutionException, datum_reader.read, decoder)
- def test_projection(self):
- print_test_name('TEST PROJECTION')
- writers_schema = LONG_RECORD_SCHEMA
- datum_to_write = LONG_RECORD_DATUM
+ def test_projection(self):
+ print_test_name('TEST PROJECTION')
+ writers_schema = LONG_RECORD_SCHEMA
+ datum_to_write = LONG_RECORD_DATUM
- readers_schema = schema.parse("""\
+ readers_schema = schema.parse("""\
{"type": "record", "name": "Test",
"fields": [{"name": "E", "type": "int"},
{"name": "F", "type": "int"}]}""")
- datum_to_read = {'E': 5, 'F': 6}
+ datum_to_read = {'E': 5, 'F': 6}
- writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
- datum_read = read_datum(writer, writers_schema, readers_schema)
- print('Datum Read: %s' % datum_read)
- self.assertEquals(datum_to_read, datum_read)
+ writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
+ datum_read = read_datum(writer, writers_schema, readers_schema)
+ print('Datum Read: %s' % datum_read)
+ self.assertEquals(datum_to_read, datum_read)
- def test_field_order(self):
- print_test_name('TEST FIELD ORDER')
- writers_schema = LONG_RECORD_SCHEMA
- datum_to_write = LONG_RECORD_DATUM
+ def test_field_order(self):
+ print_test_name('TEST FIELD ORDER')
+ writers_schema = LONG_RECORD_SCHEMA
+ datum_to_write = LONG_RECORD_DATUM
- readers_schema = schema.parse("""\
+ readers_schema = schema.parse("""\
{"type": "record", "name": "Test",
"fields": [{"name": "F", "type": "int"},
{"name": "E", "type": "int"}]}""")
- datum_to_read = {'E': 5, 'F': 6}
+ datum_to_read = {'E': 5, 'F': 6}
- writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
- datum_read = read_datum(writer, writers_schema, readers_schema)
- print('Datum Read: %s' % datum_read)
- self.assertEquals(datum_to_read, datum_read)
+ writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
+ datum_read = read_datum(writer, writers_schema, readers_schema)
+ print('Datum Read: %s' % datum_read)
+ self.assertEquals(datum_to_read, datum_read)
- def test_type_exception(self):
- print_test_name('TEST TYPE EXCEPTION')
- writers_schema = schema.parse("""\
+ def test_type_exception(self):
+ print_test_name('TEST TYPE EXCEPTION')
+ writers_schema = schema.parse("""\
{"type": "record", "name": "Test",
"fields": [{"name": "F", "type": "int"},
{"name": "E", "type": "int"}]}""")
- datum_to_write = {'E': 5, 'F': 'Bad'}
- self.assertRaises(avro.io.AvroTypeException, write_datum, datum_to_write, writers_schema)
+ datum_to_write = {'E': 5, 'F': 'Bad'}
+ self.assertRaises(avro.io.AvroTypeException, write_datum, datum_to_write, writers_schema)
+
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
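The write_datum/read_datum helpers above boil down to the following round trip through an in-memory buffer; the concrete schemas and the value 219 are simply the ones the promotion test uses, repeated here for illustration:

    import io

    import avro.io
    from avro import schema

    writers_schema = schema.parse('"int"')

    # Encode a datum into an in-memory buffer with the writer's schema.
    buf = io.BytesIO()
    encoder = avro.io.BinaryEncoder(buf)
    avro.io.DatumWriter(writers_schema).write(219, encoder)

    # Decode it again, optionally resolving against a promoted reader's schema.
    decoder = avro.io.BinaryDecoder(io.BytesIO(buf.getvalue()))
    datum = avro.io.DatumReader(writers_schema, schema.parse('"long"')).read(decoder)
    assert datum == 219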
diff --git a/lang/py/avro/test/test_ipc.py b/lang/py/avro/test/test_ipc.py
index 6be7d05..1035617 100644
--- a/lang/py/avro/test/test_ipc.py
+++ b/lang/py/avro/test/test_ipc.py
@@ -32,15 +32,16 @@ from avro import ipc
class TestIPC(unittest.TestCase):
- def test_placeholder(self):
- pass
+ def test_placeholder(self):
+ pass
-  def test_server_with_path(self):
-    client_with_custom_path = ipc.HTTPTransceiver('apache.org', 80, '/service/article')
-    self.assertEqual('/service/article', client_with_custom_path.req_resource)
-    client_with_default_path = ipc.HTTPTransceiver('apache.org', 80)
-    self.assertEqual('/', client_with_default_path.req_resource)
+  def test_server_with_path(self):
+    client_with_custom_path = ipc.HTTPTransceiver('apache.org', 80, '/service/article')
+    self.assertEqual('/service/article', client_with_custom_path.req_resource)
+
+    client_with_default_path = ipc.HTTPTransceiver('apache.org', 80)
+    self.assertEqual('/', client_with_default_path.req_resource)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/lang/py/avro/test/test_protocol.py b/lang/py/avro/test/test_protocol.py
index 323ccfc..f2b46e7 100644
--- a/lang/py/avro/test/test_protocol.py
+++ b/lang/py/avro/test/test_protocol.py
@@ -28,339 +28,342 @@ import avro.protocol
import avro.schema
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
try:
- basestring # type: ignore
+ basestring # type: ignore
except NameError:
- basestring = (bytes, unicode)
+ basestring = (bytes, unicode)
class TestProtocol(object):
- """A proxy for a protocol string that provides useful test metadata."""
+ """A proxy for a protocol string that provides useful test metadata."""
- def __init__(self, data, name='', comment=''):
- if not isinstance(data, basestring):
- data = json.dumps(data)
- self.data = data
- self.name = name or data
- self.comment = comment
+ def __init__(self, data, name='', comment=''):
+ if not isinstance(data, basestring):
+ data = json.dumps(data)
+ self.data = data
+ self.name = name or data
+ self.comment = comment
- def parse(self):
- return avro.protocol.parse(str(self))
+ def parse(self):
+ return avro.protocol.parse(str(self))
- def __str__(self):
- return str(self.data)
+ def __str__(self):
+ return str(self.data)
class ValidTestProtocol(TestProtocol):
- """A proxy for a valid protocol string that provides useful test metadata."""
- valid = True
+ """A proxy for a valid protocol string that provides useful test metadata."""
+ valid = True
class InvalidTestProtocol(TestProtocol):
- """A proxy for an invalid protocol string that provides useful test metadata."""
- valid = False
+ """A proxy for an invalid protocol string that provides useful test metadata."""
+ valid = False
HELLO_WORLD = ValidTestProtocol({
- "namespace": "com.acme",
- "protocol": "HelloWorld",
- "types": [
- {"name": "Greeting", "type": "record", "fields": [
- {"name": "message", "type": "string"}]},
- {"name": "Curse", "type": "error", "fields": [
- {"name": "message", "type": "string"}]}
- ],
- "messages": {
- "hello": {
- "request": [{"name": "greeting", "type": "Greeting" }],
- "response": "Greeting",
- "errors": ["Curse"]
+ "namespace": "com.acme",
+ "protocol": "HelloWorld",
+ "types": [
+ {"name": "Greeting", "type": "record", "fields": [
+ {"name": "message", "type": "string"}]},
+ {"name": "Curse", "type": "error", "fields": [
+ {"name": "message", "type": "string"}]}
+ ],
+ "messages": {
+ "hello": {
+ "request": [{"name": "greeting", "type": "Greeting"}],
+ "response": "Greeting",
+ "errors": ["Curse"]
+ }
}
- }
})
EXAMPLES = [HELLO_WORLD, ValidTestProtocol({
"namespace": "org.apache.avro.test",
"protocol": "Simple",
"types": [
- {"name": "Kind", "type": "enum", "symbols": ["FOO","BAR","BAZ"]},
- {"name": "MD5", "type": "fixed", "size": 16},
- {"name": "TestRecord", "type": "record", "fields": [
- {"name": "name", "type": "string", "order": "ignore"},
- {"name": "kind", "type": "Kind", "order": "descending"},
- {"name": "hash", "type": "MD5"}
- ]},
- {"name": "TestError", "type": "error", "fields": [{"name": "message", "type": "string"}]}
+ {"name": "Kind", "type": "enum", "symbols": ["FOO", "BAR", "BAZ"]},
+ {"name": "MD5", "type": "fixed", "size": 16},
+ {"name": "TestRecord", "type": "record", "fields": [
+ {"name": "name", "type": "string", "order": "ignore"},
+ {"name": "kind", "type": "Kind", "order": "descending"},
+ {"name": "hash", "type": "MD5"}
+ ]},
+ {"name": "TestError", "type": "error", "fields": [{"name": "message", "type": "string"}]}
],
"messages": {
- "hello": {
- "request": [{"name": "greeting", "type": "string"}],
- "response": "string"
- }, "echo": {
- "request": [{"name": "record", "type": "TestRecord"}],
- "response": "TestRecord"
- }, "add": {
- "request": [{"name": "arg1", "type": "int"}, {"name": "arg2", "type": "int"}],
- "response": "int"
- }, "echoBytes": {
- "request": [{"name": "data", "type": "bytes"}],
- "response": "bytes"
- }, "error": {
- "request": [],
- "response": "null",
- "errors": ["TestError"]
- }
+ "hello": {
+ "request": [{"name": "greeting", "type": "string"}],
+ "response": "string"
+ }, "echo": {
+ "request": [{"name": "record", "type": "TestRecord"}],
+ "response": "TestRecord"
+ }, "add": {
+ "request": [{"name": "arg1", "type": "int"}, {"name": "arg2", "type": "int"}],
+ "response": "int"
+ }, "echoBytes": {
+ "request": [{"name": "data", "type": "bytes"}],
+ "response": "bytes"
+ }, "error": {
+ "request": [],
+ "response": "null",
+ "errors": ["TestError"]
+ }
}
- }), ValidTestProtocol({
+}), ValidTestProtocol({
"namespace": "org.apache.avro.test.namespace",
"protocol": "TestNamespace",
"types": [
- {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
- {"name": "TestRecord", "type": "record", "fields": [
- {"name": "hash", "type": "org.apache.avro.test.util.MD5"}
- ]},
- {"name": "TestError", "namespace": "org.apache.avro.test.errors", "type": "error",
- "fields": [ {"name": "message", "type": "string"}]}
+ {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
+ {"name": "TestRecord", "type": "record", "fields": [
+ {"name": "hash", "type": "org.apache.avro.test.util.MD5"}
+ ]},
+ {"name": "TestError", "namespace": "org.apache.avro.test.errors", "type": "error",
+ "fields": [{"name": "message", "type": "string"}]}
],
"messages": {
- "echo": {
- "request": [{"name": "record", "type": "TestRecord"}],
- "response": "TestRecord"
- }, "error": {
- "request": [],
- "response": "null",
- "errors": ["org.apache.avro.test.errors.TestError"]
- }
+ "echo": {
+ "request": [{"name": "record", "type": "TestRecord"}],
+ "response": "TestRecord"
+ }, "error": {
+ "request": [],
+ "response": "null",
+ "errors": ["org.apache.avro.test.errors.TestError"]
+ }
}
- }), ValidTestProtocol({
+}), ValidTestProtocol({
"namespace": "org.apache.avro.test.namespace",
"protocol": "TestImplicitNamespace",
"types": [
- {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
- {"name": "ReferencedRecord", "type": "record",
- "fields": [ {"name": "foo", "type": "string"}]},
- {"name": "TestRecord", "type": "record",
- "fields": [{"name": "hash", "type": "org.apache.avro.test.util.MD5"},
- {"name": "unqualified", "type": "ReferencedRecord"}]
- },
- {"name": "TestError", "type": "error", "fields": [{"name": "message", "type": "string"}]}
+ {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
+ {"name": "ReferencedRecord", "type": "record",
+ "fields": [{"name": "foo", "type": "string"}]},
+ {"name": "TestRecord", "type": "record",
+ "fields": [{"name": "hash", "type": "org.apache.avro.test.util.MD5"},
+ {"name": "unqualified", "type": "ReferencedRecord"}]
+ },
+ {"name": "TestError", "type": "error", "fields": [{"name": "message", "type": "string"}]}
],
"messages": {
- "echo": {
- "request": [{"name": "qualified", "type": "org.apache.avro.test.namespace.TestRecord"}],
- "response": "TestRecord"
- }, "error": {
- "request": [],
- "response": "null",
- "errors": ["org.apache.avro.test.namespace.TestError"]
- }
+ "echo": {
+ "request": [{"name": "qualified", "type": "org.apache.avro.test.namespace.TestRecord"}],
+ "response": "TestRecord"
+ }, "error": {
+ "request": [],
+ "response": "null",
+ "errors": ["org.apache.avro.test.namespace.TestError"]
+ }
}
- }), ValidTestProtocol({
+}), ValidTestProtocol({
"namespace": "org.apache.avro.test.namespace",
"protocol": "TestNamespaceTwo",
"types": [
- {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
- {"name": "ReferencedRecord", "type": "record",
- "namespace": "org.apache.avro.other.namespace",
- "fields": [{"name": "foo", "type": "string"}]},
- {"name": "TestRecord", "type": "record",
- "fields": [{"name": "hash", "type": "org.apache.avro.test.util.MD5"},
- {"name": "qualified",
- "type": "org.apache.avro.other.namespace.ReferencedRecord"}]
- },
- {"name": "TestError",
- "type": "error", "fields": [{"name": "message", "type": "string"}]}],
+ {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
+ {"name": "ReferencedRecord", "type": "record",
+ "namespace": "org.apache.avro.other.namespace",
+ "fields": [{"name": "foo", "type": "string"}]},
+ {"name": "TestRecord", "type": "record",
+ "fields": [{"name": "hash", "type": "org.apache.avro.test.util.MD5"},
+ {"name": "qualified",
+ "type": "org.apache.avro.other.namespace.ReferencedRecord"}]
+ },
+ {"name": "TestError",
+ "type": "error", "fields": [{"name": "message", "type": "string"}]}],
"messages": {
- "echo": {
- "request": [{"name": "qualified", "type": "org.apache.avro.test.namespace.TestRecord"}],
- "response": "TestRecord"
- }, "error": {
- "request": [],
- "response": "null",
- "errors": ["org.apache.avro.test.namespace.TestError"]
- }
+ "echo": {
+ "request": [{"name": "qualified", "type": "org.apache.avro.test.namespace.TestRecord"}],
+ "response": "TestRecord"
+ }, "error": {
+ "request": [],
+ "response": "null",
+ "errors": ["org.apache.avro.test.namespace.TestError"]
+ }
}
- }), ValidTestProtocol({
+}), ValidTestProtocol({
"namespace": "org.apache.avro.test.namespace",
"protocol": "TestValidRepeatedName",
"types": [
- {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
- {"name": "ReferencedRecord", "type": "record",
- "namespace": "org.apache.avro.other.namespace",
- "fields": [{"name": "foo", "type": "string"}]},
- {"name": "ReferencedRecord", "type": "record",
- "fields": [{"name": "bar", "type": "double"}]},
- {"name": "TestError",
- "type": "error", "fields": [{"name": "message", "type": "string"}]}],
+ {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
+ {"name": "ReferencedRecord", "type": "record",
+ "namespace": "org.apache.avro.other.namespace",
+ "fields": [{"name": "foo", "type": "string"}]},
+ {"name": "ReferencedRecord", "type": "record",
+ "fields": [{"name": "bar", "type": "double"}]},
+ {"name": "TestError",
+ "type": "error", "fields": [{"name": "message", "type": "string"}]}],
"messages": {
- "echo": {
- "request": [{"name": "qualified", "type": "ReferencedRecord"}],
- "response": "org.apache.avro.other.namespace.ReferencedRecord"},
- "error": {
- "request": [],
- "response": "null",
- "errors": ["org.apache.avro.test.namespace.TestError"]}
+ "echo": {
+ "request": [{"name": "qualified", "type": "ReferencedRecord"}],
+ "response": "org.apache.avro.other.namespace.ReferencedRecord"},
+ "error": {
+ "request": [],
+ "response": "null",
+ "errors": ["org.apache.avro.test.namespace.TestError"]}
}
- }), InvalidTestProtocol({
+}), InvalidTestProtocol({
"namespace": "org.apache.avro.test.namespace",
"protocol": "TestInvalidRepeatedName",
"types": [
- {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
- {"name": "ReferencedRecord", "type": "record",
- "fields": [ {"name": "foo", "type": "string"}]},
- {"name": "ReferencedRecord", "type": "record",
- "fields": [ {"name": "bar", "type": "double"}]},
- {"name": "TestError",
- "type": "error", "fields": [{"name": "message", "type": "string"}]}],
- "messages": {
- "echo": {
- "request": [{"name": "qualified", "type": "ReferencedRecord"}],
- "response": "org.apache.avro.other.namespace.ReferencedRecord"
- }, "error": {
- "request": [],
- "response": "null",
- "errors": ["org.apache.avro.test.namespace.TestError"]
- }
- }
- }),
- ValidTestProtocol({
- "namespace": "org.apache.avro.test",
- "protocol": "BulkData",
- "types": [],
+ {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
+ {"name": "ReferencedRecord", "type": "record",
+ "fields": [{"name": "foo", "type": "string"}]},
+ {"name": "ReferencedRecord", "type": "record",
+ "fields": [{"name": "bar", "type": "double"}]},
+ {"name": "TestError",
+ "type": "error", "fields": [{"name": "message", "type": "string"}]}],
"messages": {
- "read": {
- "request": [],
- "response": "bytes"
- }, "write": {
- "request": [ {"name": "data", "type": "bytes"} ],
- "response": "null"
- }
+ "echo": {
+ "request": [{"name": "qualified", "type": "ReferencedRecord"}],
+ "response": "org.apache.avro.other.namespace.ReferencedRecord"
+ }, "error": {
+ "request": [],
+ "response": "null",
+ "errors": ["org.apache.avro.test.namespace.TestError"]
+ }
}
- }), ValidTestProtocol({
- "protocol": "API",
- "namespace": "xyz.api",
- "types": [{
- "type": "enum",
- "name": "Symbology",
- "namespace": "xyz.api.product",
- "symbols": ["OPRA", "CUSIP", "ISIN", "SEDOL"]
- }, {
- "type": "record",
- "name": "Symbol",
- "namespace": "xyz.api.product",
- "fields": [{"name": "symbology", "type": "xyz.api.product.Symbology"},
- {"name": "symbol", "type": "string"}]
- }, {
- "type": "record",
- "name": "MultiSymbol",
- "namespace": "xyz.api.product",
- "fields": [{"name": "symbols",
- "type": {"type": "map", "values": "xyz.api.product.Symbol"}}]
- }],
- "messages": {}
- }),
+}),
+ ValidTestProtocol({
+ "namespace": "org.apache.avro.test",
+ "protocol": "BulkData",
+ "types": [],
+ "messages": {
+ "read": {
+ "request": [],
+ "response": "bytes"
+ }, "write": {
+ "request": [{"name": "data", "type": "bytes"}],
+ "response": "null"
+ }
+ }
+ }), ValidTestProtocol({
+ "protocol": "API",
+ "namespace": "xyz.api",
+ "types": [{
+ "type": "enum",
+ "name": "Symbology",
+ "namespace": "xyz.api.product",
+ "symbols": ["OPRA", "CUSIP", "ISIN", "SEDOL"]
+ }, {
+ "type": "record",
+ "name": "Symbol",
+ "namespace": "xyz.api.product",
+ "fields": [{"name": "symbology", "type": "xyz.api.product.Symbology"},
+ {"name": "symbol", "type": "string"}]
+ }, {
+ "type": "record",
+ "name": "MultiSymbol",
+ "namespace": "xyz.api.product",
+ "fields": [{"name": "symbols",
+ "type": {"type": "map", "values": "xyz.api.product.Symbol"}}]
+ }],
+ "messages": {}
+ }),
]
VALID_EXAMPLES = [e for e in EXAMPLES if getattr(e, "valid", False)]
class TestMisc(unittest.TestCase):
- def test_inner_namespace_set(self):
- print('')
- print('TEST INNER NAMESPACE')
- print('===================')
- print('')
- proto = HELLO_WORLD.parse()
- self.assertEqual(proto.namespace, "com.acme")
- self.assertEqual(proto.fullname, "com.acme.HelloWorld")
- greeting_type = proto.types_dict['Greeting']
- self.assertEqual(greeting_type.namespace, 'com.acme')
-
- def test_inner_namespace_not_rendered(self):
- proto = HELLO_WORLD.parse()
- self.assertEqual('com.acme.Greeting', proto.types[0].fullname)
- self.assertEqual('Greeting', proto.types[0].name)
- # but there shouldn't be 'namespace' rendered to json on the inner type
- self.assertFalse('namespace' in proto.to_json()['types'][0])
+ def test_inner_namespace_set(self):
+ print('')
+ print('TEST INNER NAMESPACE')
+ print('===================')
+ print('')
+ proto = HELLO_WORLD.parse()
+ self.assertEqual(proto.namespace, "com.acme")
+ self.assertEqual(proto.fullname, "com.acme.HelloWorld")
+ greeting_type = proto.types_dict['Greeting']
+ self.assertEqual(greeting_type.namespace, 'com.acme')
+
+ def test_inner_namespace_not_rendered(self):
+ proto = HELLO_WORLD.parse()
+ self.assertEqual('com.acme.Greeting', proto.types[0].fullname)
+ self.assertEqual('Greeting', proto.types[0].name)
+ # but there shouldn't be 'namespace' rendered to json on the inner type
+ self.assertFalse('namespace' in proto.to_json()['types'][0])
class ProtocolParseTestCase(unittest.TestCase):
- """Enable generating parse test cases over all the valid and invalid example protocols."""
-
- def __init__(self, test_proto):
- """Ignore the normal signature for unittest.TestCase because we are generating
- many test cases from this one class. This is safe as long as the autoloader
- ignores this class. The autoloader will ignore this class as long as it has
- no methods starting with `test_`.
- """
- super(ProtocolParseTestCase, self).__init__(
- 'parse_valid' if test_proto.valid else 'parse_invalid')
- self.test_proto = test_proto
-
- def parse_valid(self):
- """Parsing a valid protocol should not error."""
- try:
- self.test_proto.parse()
- except avro.protocol.ProtocolParseException:
- self.fail("Valid protocol failed to parse: {!s}".format(self.test_proto))
-
- def parse_invalid(self):
- """Parsing an invalid schema should error."""
- try:
- self.test_proto.parse()
- except (avro.protocol.ProtocolParseException, avro.schema.SchemaParseException):
- pass
- else:
- self.fail("Invalid protocol should not have parsed: {!s}".format(self.test_proto))
+ """Enable generating parse test cases over all the valid and invalid example protocols."""
+
+ def __init__(self, test_proto):
+ """Ignore the normal signature for unittest.TestCase because we are generating
+ many test cases from this one class. This is safe as long as the autoloader
+ ignores this class. The autoloader will ignore this class as long as it has
+ no methods starting with `test_`.
+ """
+ super(ProtocolParseTestCase, self).__init__(
+ 'parse_valid' if test_proto.valid else 'parse_invalid')
+ self.test_proto = test_proto
+
+ def parse_valid(self):
+ """Parsing a valid protocol should not error."""
+ try:
+ self.test_proto.parse()
+ except avro.protocol.ProtocolParseException:
+ self.fail("Valid protocol failed to parse: {!s}".format(self.test_proto))
+
+ def parse_invalid(self):
+ """Parsing an invalid schema should error."""
+ try:
+ self.test_proto.parse()
+ except (avro.protocol.ProtocolParseException, avro.schema.SchemaParseException):
+ pass
+ else:
+ self.fail("Invalid protocol should not have parsed: {!s}".format(self.test_proto))
+
class ErrorSchemaTestCase(unittest.TestCase):
- """Enable generating error schema test cases across all the valid test protocols."""
-
- def __init__(self, test_proto):
- """Ignore the normal signature for unittest.TestCase because we are generating
- many test cases from this one class. This is safe as long as the autoloader
- ignores this class. The autoloader will ignore this class as long as it has
- no methods starting with `test_`.
- """
- super(ErrorSchemaTestCase, self).__init__('check_error_schema_exists')
- self.test_proto = test_proto
-
- def check_error_schema_exists(self):
- """Protocol messages should always have at least a string error schema."""
- p = self.test_proto.parse()
- for k, m in p.messages.items():
- self.assertIsNotNone(m.errors, "Message {} did not have the expected implicit "
- "string error schema.".format(k))
+ """Enable generating error schema test cases across all the valid test protocols."""
+
+ def __init__(self, test_proto):
+ """Ignore the normal signature for unittest.TestCase because we are generating
+ many test cases from this one class. This is safe as long as the autoloader
+ ignores this class. The autoloader will ignore this class as long as it has
+ no methods starting with `test_`.
+ """
+ super(ErrorSchemaTestCase, self).__init__('check_error_schema_exists')
+ self.test_proto = test_proto
+
+ def check_error_schema_exists(self):
+ """Protocol messages should always have at least a string error schema."""
+ p = self.test_proto.parse()
+ for k, m in p.messages.items():
+ self.assertIsNotNone(m.errors, "Message {} did not have the expected implicit "
+ "string error schema.".format(k))
+
class RoundTripParseTestCase(unittest.TestCase):
- """Enable generating round-trip parse test cases over all the valid test protocols."""
+ """Enable generating round-trip parse test cases over all the valid test protocols."""
- def __init__(self, test_proto):
- """Ignore the normal signature for unittest.TestCase because we are generating
- many test cases from this one class. This is safe as long as the autoloader
- ignores this class. The autoloader will ignore this class as long as it has
- no methods starting with `test_`.
- """
- super(RoundTripParseTestCase, self).__init__('parse_round_trip')
- self.test_proto = test_proto
+ def __init__(self, test_proto):
+ """Ignore the normal signature for unittest.TestCase because we are generating
+ many test cases from this one class. This is safe as long as the autoloader
+ ignores this class. The autoloader will ignore this class as long as it has
+ no methods starting with `test_`.
+ """
+ super(RoundTripParseTestCase, self).__init__('parse_round_trip')
+ self.test_proto = test_proto
- def parse_round_trip(self):
- """The string of a Schema should be parseable to the same Schema."""
- parsed = self.test_proto.parse()
- round_trip = avro.protocol.parse(str(parsed))
- self.assertEqual(parsed, round_trip)
+ def parse_round_trip(self):
+ """The string of a Schema should be parseable to the same Schema."""
+ parsed = self.test_proto.parse()
+ round_trip = avro.protocol.parse(str(parsed))
+ self.assertEqual(parsed, round_trip)
def load_tests(loader, default_tests, pattern):
- """Generate test cases across many test schema."""
- suite = unittest.TestSuite()
- suite.addTests(loader.loadTestsFromTestCase(TestMisc))
- suite.addTests(ProtocolParseTestCase(ex) for ex in EXAMPLES)
- suite.addTests(RoundTripParseTestCase(ex) for ex in VALID_EXAMPLES)
- return suite
+ """Generate test cases across many test schema."""
+ suite = unittest.TestSuite()
+ suite.addTests(loader.loadTestsFromTestCase(TestMisc))
+ suite.addTests(ProtocolParseTestCase(ex) for ex in EXAMPLES)
+ suite.addTests(RoundTripParseTestCase(ex) for ex in VALID_EXAMPLES)
+ return suite
+
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
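For orientation, the parse and round-trip cases generated above reduce to this kind of check; the tiny "Ping" protocol is made up for illustration, but the avro.protocol.parse call and the equality comparison are exactly what parse_valid and parse_round_trip rely on:

    import json

    import avro.protocol

    PROTO_JSON = json.dumps({
        "namespace": "com.acme", "protocol": "Ping", "types": [],
        "messages": {"ping": {"request": [], "response": "null"}}})

    parsed = avro.protocol.parse(PROTO_JSON)
    assert parsed.fullname == 'com.acme.Ping'
    # The string form of a parsed Protocol should parse back to an equal Protocol.
    assert avro.protocol.parse(str(parsed)) == parsed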
diff --git a/lang/py/avro/test/test_schema.py b/lang/py/avro/test/test_schema.py
index 38e8360..71d4c5e 100644
--- a/lang/py/avro/test/test_schema.py
+++ b/lang/py/avro/test/test_schema.py
@@ -28,47 +28,47 @@ import warnings
from avro import schema
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
try:
- basestring # type: ignore
+ basestring # type: ignore
except NameError:
- basestring = (bytes, unicode)
+ basestring = (bytes, unicode)
try:
- from typing import List
+ from typing import List
except ImportError:
- pass
+ pass
class TestSchema(object):
- """A proxy for a schema string that provides useful test metadata."""
+ """A proxy for a schema string that provides useful test metadata."""
- def __init__(self, data, name='', comment='', warnings=None):
- if not isinstance(data, basestring):
- data = json.dumps(data)
- self.data = data
- self.name = name or data # default to data for name
- self.comment = comment
- self.warnings = warnings
+ def __init__(self, data, name='', comment='', warnings=None):
+ if not isinstance(data, basestring):
+ data = json.dumps(data)
+ self.data = data
+ self.name = name or data # default to data for name
+ self.comment = comment
+ self.warnings = warnings
- def parse(self):
- return schema.parse(str(self))
+ def parse(self):
+ return schema.parse(str(self))
- def __str__(self):
- return str(self.data)
+ def __str__(self):
+ return str(self.data)
class ValidTestSchema(TestSchema):
- """A proxy for a valid schema string that provides useful test metadata."""
- valid = True
+ """A proxy for a valid schema string that provides useful test metadata."""
+ valid = True
class InvalidTestSchema(TestSchema):
- """A proxy for an invalid schema string that provides useful test metadata."""
- valid = False
+ """A proxy for an invalid schema string that provides useful test metadata."""
+ valid = False
PRIMITIVE_EXAMPLES = [InvalidTestSchema('"True"')] # type: List[TestSchema]
@@ -79,227 +79,227 @@ PRIMITIVE_EXAMPLES.extend([ValidTestSchema('"{}"'.format(t)) for t in schema.PRI
PRIMITIVE_EXAMPLES.extend([ValidTestSchema({"type": t}) for t in schema.PRIMITIVE_TYPES])
FIXED_EXAMPLES = [
- ValidTestSchema({"type": "fixed", "name": "Test", "size": 1}),
- ValidTestSchema({"type": "fixed", "name": "MyFixed", "size": 1,
- "namespace": "org.apache.hadoop.avro"}),
- InvalidTestSchema({"type": "fixed", "name": "Missing size"}),
- InvalidTestSchema({"type": "fixed", "size": 314}),
- InvalidTestSchema({"type": "fixed", "size": 314, "name": "dr. spaceman"}, comment='AVRO-621'),
+ ValidTestSchema({"type": "fixed", "name": "Test", "size": 1}),
+ ValidTestSchema({"type": "fixed", "name": "MyFixed", "size": 1,
+ "namespace": "org.apache.hadoop.avro"}),
+ InvalidTestSchema({"type": "fixed", "name": "Missing size"}),
+ InvalidTestSchema({"type": "fixed", "size": 314}),
+ InvalidTestSchema({"type": "fixed", "size": 314, "name": "dr. spaceman"}, comment='AVRO-621'),
]
ENUM_EXAMPLES = [
- ValidTestSchema({"type": "enum", "name": "Test", "symbols": ["A", "B"]}),
- ValidTestSchema({"type": "enum", "name": "AVRO2174", "symbols": ["nowhitespace"]}),
- InvalidTestSchema({"type": "enum", "name": "Status", "symbols": "Normal Caution Critical"}),
- InvalidTestSchema({"type": "enum", "name": [0, 1, 1, 2, 3, 5, 8],
- "symbols": ["Golden", "Mean"]}),
- InvalidTestSchema({"type": "enum", "symbols" : ["I", "will", "fail", "no", "name"]}),
- InvalidTestSchema({"type": "enum", "name": "Test", "symbols": ["AA", "AA"]}),
- InvalidTestSchema({"type": "enum", "name": "AVRO2174", "symbols": ["white space"]}),
+ ValidTestSchema({"type": "enum", "name": "Test", "symbols": ["A", "B"]}),
+ ValidTestSchema({"type": "enum", "name": "AVRO2174", "symbols": ["nowhitespace"]}),
+ InvalidTestSchema({"type": "enum", "name": "Status", "symbols": "Normal Caution Critical"}),
+ InvalidTestSchema({"type": "enum", "name": [0, 1, 1, 2, 3, 5, 8],
+ "symbols": ["Golden", "Mean"]}),
+ InvalidTestSchema({"type": "enum", "symbols": ["I", "will", "fail", "no", "name"]}),
+ InvalidTestSchema({"type": "enum", "name": "Test", "symbols": ["AA", "AA"]}),
+ InvalidTestSchema({"type": "enum", "name": "AVRO2174", "symbols": ["white space"]}),
]
ARRAY_EXAMPLES = [
- ValidTestSchema({"type": "array", "items": "long"}),
- ValidTestSchema({"type": "array",
- "items": {"type": "enum", "name": "Test", "symbols": ["A", "B"]}}),
+ ValidTestSchema({"type": "array", "items": "long"}),
+ ValidTestSchema({"type": "array",
+ "items": {"type": "enum", "name": "Test", "symbols": ["A", "B"]}}),
]
MAP_EXAMPLES = [
- ValidTestSchema({"type": "map", "values": "long"}),
- ValidTestSchema({"type": "map",
- "values": {"type": "enum", "name": "Test", "symbols": ["A", "B"]}}),
+ ValidTestSchema({"type": "map", "values": "long"}),
+ ValidTestSchema({"type": "map",
+ "values": {"type": "enum", "name": "Test", "symbols": ["A", "B"]}}),
]
UNION_EXAMPLES = [
- ValidTestSchema(["string", "null", "long"]),
- InvalidTestSchema(["null", "null"]),
- InvalidTestSchema(["long", "long"]),
- InvalidTestSchema([{"type": "array", "items": "long"},
- {"type": "array", "items": "string"}]),
+ ValidTestSchema(["string", "null", "long"]),
+ InvalidTestSchema(["null", "null"]),
+ InvalidTestSchema(["long", "long"]),
+ InvalidTestSchema([{"type": "array", "items": "long"},
+ {"type": "array", "items": "string"}]),
]
RECORD_EXAMPLES = [
- ValidTestSchema({"type": "record", "name": "Test", "fields": [{"name": "f", "type": "long"}]}),
- ValidTestSchema({"type": "error", "name": "Test", "fields": [{"name": "f", "type": "long"}]}),
- ValidTestSchema({"type": "record", "name": "Node",
- "fields": [
- {"name": "label", "type": "string"},
- {"name": "children", "type": {"type": "array", "items": "Node"}}]}),
- ValidTestSchema({"type": "record", "name": "Lisp",
- "fields": [{"name": "value",
- "type": ["null", "string",
- {"type": "record", "name": "Cons",
- "fields": [{"name": "car", "type": "Lisp"},
- {"name": "cdr", "type": "Lisp"}]}]}]}),
- ValidTestSchema({"type": "record", "name": "HandshakeRequest",
- "namespace": "org.apache.avro.ipc",
- "fields": [{"name": "clientHash",
- "type": {"type": "fixed", "name": "MD5", "size": 16}},
- {"name": "clientProtocol", "type": ["null", "string"]},
- {"name": "serverHash", "type": "MD5"},
- {"name": "meta",
- "type": ["null", {"type": "map", "values": "bytes"}]}]}),
- ValidTestSchema({"type": "record", "name": "HandshakeResponse",
- "namespace": "org.apache.avro.ipc",
- "fields": [{"name": "match",
- "type": {"type": "enum", "name": "HandshakeMatch",
- "symbols": ["BOTH", "CLIENT", "NONE"]}},
- {"name": "serverProtocol", "type": ["null", "string"]},
- {"name": "serverHash",
- "type": ["null", {"name": "MD5", "size": 16, "type": "fixed"}]},
- {"name": "meta",
- "type": ["null", {"type": "map", "values": "bytes"}]}]}),
- ValidTestSchema({"type": "record",
- "name": "Interop",
- "namespace": "org.apache.avro",
- "fields": [{"name": "intField", "type": "int"},
- {"name": "longField", "type": "long"},
- {"name": "stringField", "type": "string"},
- {"name": "boolField", "type": "boolean"},
- {"name": "floatField", "type": "float"},
- {"name": "doubleField", "type": "double"},
- {"name": "bytesField", "type": "bytes"},
- {"name": "nullField", "type": "null"},
- {"name": "arrayField", "type": {"type": "array", "items": "double"}},
- {"name": "mapField",
- "type": {"type": "map",
- "values": {"name": "Foo",
- "type": "record",
- "fields": [{"name": "label", "type": "string"}]}}},
- {"name": "unionField",
- "type": ["boolean", "double", {"type": "array", "items": "bytes"}]},
- {"name": "enumField",
- "type": {"type": "enum", "name": "Kind", "symbols": ["A", "B", "C"]}},
- {"name": "fixedField",
- "type": {"type": "fixed", "name": "MD5", "size": 16}},
- {"name": "recordField",
- "type": {"type": "record", "name": "Node",
- "fields": [{"name": "label", "type": "string"},
- {"name": "children",
- "type": {"type": "array",
- "items": "Node"}}]}}]}),
- ValidTestSchema({"type": "record", "name": "ipAddr",
- "fields": [{"name": "addr", "type": [{"name": "IPv6", "type": "fixed", "size": 16},
- {"name": "IPv4", "type": "fixed", "size": 4}]}]}),
- InvalidTestSchema({"type": "record", "name": "Address",
- "fields": [{"type": "string"}, {"type": "string", "name": "City"}]}),
- InvalidTestSchema({"type": "record", "name": "Event",
- "fields": [{"name": "Sponsor"}, {"name": "City", "type": "string"}]}),
- InvalidTestSchema({"type": "record", "name": "Rainer",
- "fields": "His vision, from the constantly passing bars"}),
- InvalidTestSchema({"name": ["Tom", "Jerry"], "type": "record",
- "fields": [{"name": "name", "type": "string"}]}),
+ ValidTestSchema({"type": "record", "name": "Test", "fields": [{"name": "f", "type": "long"}]}),
+ ValidTestSchema({"type": "error", "name": "Test", "fields": [{"name": "f", "type": "long"}]}),
+ ValidTestSchema({"type": "record", "name": "Node",
+ "fields": [
+ {"name": "label", "type": "string"},
+ {"name": "children", "type": {"type": "array", "items": "Node"}}]}),
+ ValidTestSchema({"type": "record", "name": "Lisp",
+ "fields": [{"name": "value",
+ "type": ["null", "string",
+ {"type": "record", "name": "Cons",
+ "fields": [{"name": "car", "type": "Lisp"},
+ {"name": "cdr", "type": "Lisp"}]}]}]}),
+ ValidTestSchema({"type": "record", "name": "HandshakeRequest",
+ "namespace": "org.apache.avro.ipc",
+ "fields": [{"name": "clientHash",
+ "type": {"type": "fixed", "name": "MD5", "size": 16}},
+ {"name": "clientProtocol", "type": ["null", "string"]},
+ {"name": "serverHash", "type": "MD5"},
+ {"name": "meta",
+ "type": ["null", {"type": "map", "values": "bytes"}]}]}),
+ ValidTestSchema({"type": "record", "name": "HandshakeResponse",
+ "namespace": "org.apache.avro.ipc",
+ "fields": [{"name": "match",
+ "type": {"type": "enum", "name": "HandshakeMatch",
+ "symbols": ["BOTH", "CLIENT", "NONE"]}},
+ {"name": "serverProtocol", "type": ["null", "string"]},
+ {"name": "serverHash",
+ "type": ["null", {"name": "MD5", "size": 16, "type": "fixed"}]},
+ {"name": "meta",
+ "type": ["null", {"type": "map", "values": "bytes"}]}]}),
+ ValidTestSchema({"type": "record",
+ "name": "Interop",
+ "namespace": "org.apache.avro",
+ "fields": [{"name": "intField", "type": "int"},
+ {"name": "longField", "type": "long"},
+ {"name": "stringField", "type": "string"},
+ {"name": "boolField", "type": "boolean"},
+ {"name": "floatField", "type": "float"},
+ {"name": "doubleField", "type": "double"},
+ {"name": "bytesField", "type": "bytes"},
+ {"name": "nullField", "type": "null"},
+ {"name": "arrayField", "type": {"type": "array", "items": "double"}},
+ {"name": "mapField",
+ "type": {"type": "map",
+ "values": {"name": "Foo",
+ "type": "record",
+ "fields": [{"name": "label", "type": "string"}]}}},
+ {"name": "unionField",
+ "type": ["boolean", "double", {"type": "array", "items": "bytes"}]},
+ {"name": "enumField",
+ "type": {"type": "enum", "name": "Kind", "symbols": ["A", "B", "C"]}},
+ {"name": "fixedField",
+ "type": {"type": "fixed", "name": "MD5", "size": 16}},
+ {"name": "recordField",
+ "type": {"type": "record", "name": "Node",
+ "fields": [{"name": "label", "type": "string"},
+ {"name": "children",
+ "type": {"type": "array",
+ "items": "Node"}}]}}]}),
+ ValidTestSchema({"type": "record", "name": "ipAddr",
+ "fields": [{"name": "addr", "type": [{"name": "IPv6", "type": "fixed", "size": 16},
+ {"name": "IPv4", "type": "fixed", "size": 4}]}]}),
+ InvalidTestSchema({"type": "record", "name": "Address",
+ "fields": [{"type": "string"}, {"type": "string", "name": "City"}]}),
+ InvalidTestSchema({"type": "record", "name": "Event",
+ "fields": [{"name": "Sponsor"}, {"name": "City", "type": "string"}]}),
+ InvalidTestSchema({"type": "record", "name": "Rainer",
+ "fields": "His vision, from the constantly passing bars"}),
+ InvalidTestSchema({"name": ["Tom", "Jerry"], "type": "record",
+ "fields": [{"name": "name", "type": "string"}]}),
]
DOC_EXAMPLES = [
- ValidTestSchema({"type": "record", "name": "TestDoc", "doc": "Doc string",
- "fields": [{"name": "name", "type": "string", "doc" : "Doc String"}]}),
- ValidTestSchema({"type": "enum", "name": "Test", "symbols": ["A", "B"], "doc": "Doc String"}),
+ ValidTestSchema({"type": "record", "name": "TestDoc", "doc": "Doc string",
+ "fields": [{"name": "name", "type": "string", "doc": "Doc String"}]}),
+ ValidTestSchema({"type": "enum", "name": "Test", "symbols": ["A", "B"], "doc": "Doc String"}),
]
OTHER_PROP_EXAMPLES = [
- ValidTestSchema({"type": "record", "name": "TestRecord", "cp_string": "string",
- "cp_int": 1, "cp_array": [1, 2, 3, 4],
- "fields": [{"name": "f1", "type": "string", "cp_object": {"a": 1,"b": 2}},
- {"name": "f2", "type": "long", "cp_null": None}]}),
- ValidTestSchema({"type": "map", "values": "long", "cp_boolean": True}),
- ValidTestSchema({"type": "enum", "name": "TestEnum",
- "symbols": ["one", "two", "three"], "cp_float": 1.0}),
+ ValidTestSchema({"type": "record", "name": "TestRecord", "cp_string": "string",
+ "cp_int": 1, "cp_array": [1, 2, 3, 4],
+ "fields": [{"name": "f1", "type": "string", "cp_object": {"a": 1, "b": 2}},
+ {"name": "f2", "type": "long", "cp_null": None}]}),
+ ValidTestSchema({"type": "map", "values": "long", "cp_boolean": True}),
+ ValidTestSchema({"type": "enum", "name": "TestEnum",
+ "symbols": ["one", "two", "three"], "cp_float": 1.0}),
]
DECIMAL_LOGICAL_TYPE = [
- ValidTestSchema({"type": "fixed", "logicalType": "decimal", "name": "TestDecimal", "precision": 4, "size": 10, "scale": 2}),
- ValidTestSchema({"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}),
- InvalidTestSchema({"type": "fixed", "logicalType": "decimal", "name": "TestDecimal2", "precision": 2, "scale": 2, "size": -2}),
+ ValidTestSchema({"type": "fixed", "logicalType": "decimal", "name": "TestDecimal", "precision": 4, "size": 10, "scale": 2}),
+ ValidTestSchema({"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}),
+ InvalidTestSchema({"type": "fixed", "logicalType": "decimal", "name": "TestDecimal2", "precision": 2, "scale": 2, "size": -2}),
]
DATE_LOGICAL_TYPE = [
- ValidTestSchema({"type": "int", "logicalType": "date"})
+ ValidTestSchema({"type": "int", "logicalType": "date"})
]
TIMEMILLIS_LOGICAL_TYPE = [
- ValidTestSchema({"type": "int", "logicalType": "time-millis"})
+ ValidTestSchema({"type": "int", "logicalType": "time-millis"})
]
TIMEMICROS_LOGICAL_TYPE = [
- ValidTestSchema({"type": "long", "logicalType": "time-micros"})
+ ValidTestSchema({"type": "long", "logicalType": "time-micros"})
]
TIMESTAMPMILLIS_LOGICAL_TYPE = [
- ValidTestSchema({"type": "long", "logicalType": "timestamp-millis"})
+ ValidTestSchema({"type": "long", "logicalType": "timestamp-millis"})
]
TIMESTAMPMICROS_LOGICAL_TYPE = [
- ValidTestSchema({"type": "long", "logicalType": "timestamp-micros"})
+ ValidTestSchema({"type": "long", "logicalType": "timestamp-micros"})
]
IGNORED_LOGICAL_TYPE = [
- ValidTestSchema(
- {"type": "string", "logicalType": "uuid"},
- warnings=[schema.IgnoredLogicalType('Unknown uuid, using string.')]),
- ValidTestSchema(
- {"type": "string", "logicalType": "unknown-logical-type"},
- warnings=[schema.IgnoredLogicalType('Unknown unknown-logical-type, using string.')]),
- ValidTestSchema(
- {"type": "bytes", "logicalType": "decimal", "scale": 0},
- warnings=[schema.IgnoredLogicalType('Invalid decimal precision None. Must be a positive integer.')]),
- ValidTestSchema(
- {"type": "bytes", "logicalType": "decimal", "precision": 2.4, "scale": 0},
- warnings=[schema.IgnoredLogicalType('Invalid decimal precision 2.4. Must be a positive integer.')]),
- ValidTestSchema(
- {"type": "bytes", "logicalType": "decimal", "precision": 2, "scale": -2},
- warnings=[schema.IgnoredLogicalType('Invalid decimal scale -2. Must be a positive integer.')]),
- ValidTestSchema(
- {"type": "bytes", "logicalType": "decimal", "precision": -2, "scale": 2},
- warnings=[schema.IgnoredLogicalType('Invalid decimal precision -2. Must be a positive integer.')]),
- ValidTestSchema(
- {"type": "bytes", "logicalType": "decimal", "precision": 2, "scale": 3},
- warnings=[schema.IgnoredLogicalType('Invalid decimal scale 3. Cannot be greater than precision 2.')]),
- ValidTestSchema(
- {"type": "fixed", "logicalType": "decimal", "name": "TestIgnored", "precision": -10, "scale": 2, "size": 5},
- warnings=[schema.IgnoredLogicalType('Invalid decimal precision -10. Must be a positive integer.')]),
- ValidTestSchema(
- {"type": "fixed", "logicalType": "decimal", "name": "TestIgnored", "scale": 2, "size": 5},
- warnings=[schema.IgnoredLogicalType('Invalid decimal precision None. Must be a positive integer.')]),
- ValidTestSchema(
- {"type": "fixed", "logicalType": "decimal", "name": "TestIgnored", "precision": 2, "scale": 3, "size": 2},
- warnings=[schema.IgnoredLogicalType('Invalid decimal scale 3. Cannot be greater than precision 2.')]),
- ValidTestSchema(
- {"type": "fixed", "logicalType": "decimal", "name": "TestIgnored", "precision": 311, "size": 129},
- warnings=[schema.IgnoredLogicalType('Invalid decimal precision 311. Max is 310.')]),
- ValidTestSchema(
- {"type": "float", "logicalType": "decimal", "precision": 2, "scale": 0},
- warnings=[schema.IgnoredLogicalType('Logical type decimal requires literal type bytes/fixed, not float.')]),
- ValidTestSchema(
- {"type": "int", "logicalType": "date1"},
- warnings=[schema.IgnoredLogicalType('Unknown date1, using int.')]),
- ValidTestSchema(
- {"type": "long", "logicalType": "date"},
- warnings=[schema.IgnoredLogicalType('Logical type date requires literal type int, not long.')]),
- ValidTestSchema(
- {"type": "int", "logicalType": "time-milis"},
- warnings=[schema.IgnoredLogicalType('Unknown time-milis, using int.')]),
- ValidTestSchema(
- {"type": "long", "logicalType": "time-millis"},
- warnings=[schema.IgnoredLogicalType('Logical type time-millis requires literal type int, not long.')]),
- ValidTestSchema(
- {"type": "long", "logicalType": "time-micro"},
- warnings=[schema.IgnoredLogicalType('Unknown time-micro, using long.')]),
- ValidTestSchema(
- {"type": "int", "logicalType": "time-micros"},
- warnings=[schema.IgnoredLogicalType('Logical type time-micros requires literal type long, not int.')]),
- ValidTestSchema(
- {"type": "long", "logicalType": "timestamp-milis"},
- warnings=[schema.IgnoredLogicalType('Unknown timestamp-milis, using long.')]),
- ValidTestSchema(
- {"type": "int", "logicalType": "timestamp-millis"},
- warnings=[schema.IgnoredLogicalType('Logical type timestamp-millis requires literal type long, not int.')]),
- ValidTestSchema(
- {"type": "long", "logicalType": "timestamp-micro"},
- warnings=[schema.IgnoredLogicalType('Unknown timestamp-micro, using long.')]),
- ValidTestSchema(
- {"type": "int", "logicalType": "timestamp-micros"},
- warnings=[schema.IgnoredLogicalType('Logical type timestamp-micros requires literal type long, not int.')])
+ ValidTestSchema(
+ {"type": "string", "logicalType": "uuid"},
+ warnings=[schema.IgnoredLogicalType('Unknown uuid, using string.')]),
+ ValidTestSchema(
+ {"type": "string", "logicalType": "unknown-logical-type"},
+ warnings=[schema.IgnoredLogicalType('Unknown unknown-logical-type, using string.')]),
+ ValidTestSchema(
+ {"type": "bytes", "logicalType": "decimal", "scale": 0},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal precision None. Must be a positive integer.')]),
+ ValidTestSchema(
+ {"type": "bytes", "logicalType": "decimal", "precision": 2.4, "scale": 0},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal precision 2.4. Must be a positive integer.')]),
+ ValidTestSchema(
+ {"type": "bytes", "logicalType": "decimal", "precision": 2, "scale": -2},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal scale -2. Must be a positive integer.')]),
+ ValidTestSchema(
+ {"type": "bytes", "logicalType": "decimal", "precision": -2, "scale": 2},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal precision -2. Must be a positive integer.')]),
+ ValidTestSchema(
+ {"type": "bytes", "logicalType": "decimal", "precision": 2, "scale": 3},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal scale 3. Cannot be greater than precision 2.')]),
+ ValidTestSchema(
+ {"type": "fixed", "logicalType": "decimal", "name": "TestIgnored", "precision": -10, "scale": 2, "size": 5},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal precision -10. Must be a positive integer.')]),
+ ValidTestSchema(
+ {"type": "fixed", "logicalType": "decimal", "name": "TestIgnored", "scale": 2, "size": 5},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal precision None. Must be a positive integer.')]),
+ ValidTestSchema(
+ {"type": "fixed", "logicalType": "decimal", "name": "TestIgnored", "precision": 2, "scale": 3, "size": 2},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal scale 3. Cannot be greater than precision 2.')]),
+ ValidTestSchema(
+ {"type": "fixed", "logicalType": "decimal", "name": "TestIgnored", "precision": 311, "size": 129},
+ warnings=[schema.IgnoredLogicalType('Invalid decimal precision 311. Max is 310.')]),
+ ValidTestSchema(
+ {"type": "float", "logicalType": "decimal", "precision": 2, "scale": 0},
+ warnings=[schema.IgnoredLogicalType('Logical type decimal requires literal type bytes/fixed, not float.')]),
+ ValidTestSchema(
+ {"type": "int", "logicalType": "date1"},
+ warnings=[schema.IgnoredLogicalType('Unknown date1, using int.')]),
+ ValidTestSchema(
+ {"type": "long", "logicalType": "date"},
+ warnings=[schema.IgnoredLogicalType('Logical type date requires literal type int, not long.')]),
+ ValidTestSchema(
+ {"type": "int", "logicalType": "time-milis"},
+ warnings=[schema.IgnoredLogicalType('Unknown time-milis, using int.')]),
+ ValidTestSchema(
+ {"type": "long", "logicalType": "time-millis"},
+ warnings=[schema.IgnoredLogicalType('Logical type time-millis requires literal type int, not long.')]),
+ ValidTestSchema(
+ {"type": "long", "logicalType": "time-micro"},
+ warnings=[schema.IgnoredLogicalType('Unknown time-micro, using long.')]),
+ ValidTestSchema(
+ {"type": "int", "logicalType": "time-micros"},
+ warnings=[schema.IgnoredLogicalType('Logical type time-micros requires literal type long, not int.')]),
+ ValidTestSchema(
+ {"type": "long", "logicalType": "timestamp-milis"},
+ warnings=[schema.IgnoredLogicalType('Unknown timestamp-milis, using long.')]),
+ ValidTestSchema(
+ {"type": "int", "logicalType": "timestamp-millis"},
+ warnings=[schema.IgnoredLogicalType('Logical type timestamp-millis requires literal type long, not int.')]),
+ ValidTestSchema(
+ {"type": "long", "logicalType": "timestamp-micro"},
+ warnings=[schema.IgnoredLogicalType('Unknown timestamp-micro, using long.')]),
+ ValidTestSchema(
+ {"type": "int", "logicalType": "timestamp-micros"},
+ warnings=[schema.IgnoredLogicalType('Logical type timestamp-micros requires literal type long, not int.')])
]
EXAMPLES = PRIMITIVE_EXAMPLES
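The decimal examples above (and the fixed-decimal precision tests later in this file) hinge on how many base-10 digits an N-byte two's-complement integer can hold: a fixed of size N supports a precision of at most floor(log10(2^(8N-1) - 1)). A minimal sketch of that arithmetic follows; the helper name is illustrative and is not part of avro.schema.

    import math

    def max_precision(size_in_bytes):
        """Largest decimal precision an Avro fixed of this size can always hold."""
        # An N-byte two's-complement integer can store values up to 2**(8*N - 1) - 1.
        return int(math.floor(math.log10(2 ** (8 * size_in_bytes - 1) - 1)))

    assert max_precision(8) == 18    # every 18-digit unscaled value fits in 8 bytes
    assert max_precision(8) < 19     # but not every 19-digit value does
    assert max_precision(129) == 310  # matches the "Max is 310" example above
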
@@ -323,277 +323,289 @@ INVALID_EXAMPLES = [e for e in EXAMPLES if not getattr(e, "valid", True)]
class TestMisc(unittest.TestCase):
- """Miscellaneous tests for schema"""
-
- def test_correct_recursive_extraction(self):
- """A recursive reference within a schema should be the same type every time."""
- s = schema.parse('{"type": "record", "name": "X", "fields": [{"name": "y", "type": {"type": "record", "name": "Y", "fields": [{"name": "Z", "type": "X"}]}}]}')
- t = schema.parse(str(s.fields[0].type))
-    # If we've made it this far, the subschema was reasonably stringified; it could be reparsed.
- self.assertEqual("X", t.fields[0].type.name)
-
- def test_name_is_none(self):
- """When a name is None its namespace is None."""
- self.assertIsNone(schema.Name(None, None, None).fullname)
- self.assertIsNone(schema.Name(None, None, None).space)
-
- def test_name_not_empty_string(self):
- """A name cannot be the empty string."""
- self.assertRaises(schema.SchemaParseException, schema.Name, "", None, None)
-
- def test_name_space_specified(self):
- """Space combines with a name to become the fullname."""
- # name and namespace specified
- fullname = schema.Name('a', 'o.a.h', None).fullname
- self.assertEqual(fullname, 'o.a.h.a')
-
- def test_fullname_space_specified(self):
- """When name contains dots, namespace should be ignored."""
- fullname = schema.Name('a.b.c.d', 'o.a.h', None).fullname
- self.assertEqual(fullname, 'a.b.c.d')
-
- def test_name_default_specified(self):
- """Default space becomes the namespace when the namespace is None."""
- fullname = schema.Name('a', None, 'b.c.d').fullname
- self.assertEqual(fullname, 'b.c.d.a')
-
- def test_fullname_default_specified(self):
- """When a name contains dots, default space should be ignored."""
- fullname = schema.Name('a.b.c.d', None, 'o.a.h').fullname
- self.assertEqual(fullname, 'a.b.c.d')
-
- def test_fullname_space_default_specified(self):
- """When a name contains dots, namespace and default space should be ignored."""
- fullname = schema.Name('a.b.c.d', 'o.a.a', 'o.a.h').fullname
- self.assertEqual(fullname, 'a.b.c.d')
-
- def test_name_space_default_specified(self):
- """When name and space are specified, default space should be ignored."""
- fullname = schema.Name('a', 'o.a.a', 'o.a.h').fullname
- self.assertEqual(fullname, 'o.a.a.a')
-
- def test_equal_names(self):
- """Equality of names is defined on the fullname and is case-sensitive."""
- self.assertEqual(schema.Name('a.b.c.d', None, None), schema.Name('d', 'a.b.c', None))
- self.assertNotEqual(schema.Name('C.d', None, None), schema.Name('c.d', None, None))
-
- def test_invalid_name(self):
- """The name portion of a fullname, record field names, and enum symbols must:
- start with [A-Za-z_] and subsequently contain only [A-Za-z0-9_]"""
- self.assertRaises(schema.InvalidName, schema.Name, 'an especially spacey cowboy', None, None)
- self.assertRaises(schema.InvalidName, schema.Name, '99 problems but a name aint one', None, None)
-
- def test_null_namespace(self):
- """The empty string may be used as a namespace to indicate the null namespace."""
- name = schema.Name('name', "", None)
- self.assertEqual(name.fullname, "name")
- self.assertIsNone(name.space)
-
- def test_exception_is_not_swallowed_on_parse_error(self):
- """A specific exception message should appear on a json parse error."""
- self.assertRaisesRegexp(schema.SchemaParseException,
- r'Error parsing JSON: /not/a/real/file',
- schema.parse,
- '/not/a/real/file')
-
- def test_decimal_valid_type(self):
- fixed_decimal_schema = ValidTestSchema({
- "type": "fixed",
- "logicalType": "decimal",
- "name": "TestDecimal",
- "precision": 4,
- "scale": 2,
- "size": 2})
-
- bytes_decimal_schema = ValidTestSchema({
- "type": "bytes",
- "logicalType": "decimal",
- "precision": 4})
-
- fixed_decimal = fixed_decimal_schema.parse()
- self.assertEqual(4, fixed_decimal.get_prop('precision'))
- self.assertEqual(2, fixed_decimal.get_prop('scale'))
- self.assertEqual(2, fixed_decimal.get_prop('size'))
-
- bytes_decimal = bytes_decimal_schema.parse()
- self.assertEqual(4, bytes_decimal.get_prop('precision'))
- self.assertEqual(0, bytes_decimal.get_prop('scale'))
-
- def test_fixed_decimal_valid_max_precision(self):
- # An 8 byte number can represent any 18 digit number.
- fixed_decimal_schema = ValidTestSchema({
- "type": "fixed",
- "logicalType": "decimal",
- "name": "TestDecimal",
- "precision": 18,
- "scale": 0,
- "size": 8})
-
- fixed_decimal = fixed_decimal_schema.parse()
- self.assertIsInstance(fixed_decimal, schema.FixedSchema)
- self.assertIsInstance(fixed_decimal, schema.DecimalLogicalSchema)
-
- def test_fixed_decimal_invalid_max_precision(self):
- # An 8 byte number can't represent every 19 digit number, so the logical
- # type is not applied.
- fixed_decimal_schema = ValidTestSchema({
- "type": "fixed",
- "logicalType": "decimal",
- "name": "TestDecimal",
- "precision": 19,
- "scale": 0,
- "size": 8})
-
- fixed_decimal = fixed_decimal_schema.parse()
- self.assertIsInstance(fixed_decimal, schema.FixedSchema)
- self.assertNotIsInstance(fixed_decimal, schema.DecimalLogicalSchema)
-
- def test_parse_invalid_symbol(self):
- """Disabling enumschema symbol validation should allow invalid symbols to pass."""
- test_schema_string = json.dumps({
- "type": "enum", "name": "AVRO2174", "symbols": ["white space"]})
-
- try:
- case = schema.parse(test_schema_string, validate_enum_symbols=True)
- except schema.InvalidName:
- pass
- else:
- self.fail("When enum symbol validation is enabled, "
- "an invalid symbol should raise InvalidName.")
-
- try:
- case = schema.parse(test_schema_string, validate_enum_symbols=False)
- except schema.InvalidName:
- self.fail("When enum symbol validation is disabled, "
- "an invalid symbol should not raise InvalidName.")
+ """Miscellaneous tests for schema"""
+
+ def test_correct_recursive_extraction(self):
+ """A recursive reference within a schema should be the same type every time."""
+ s = schema.parse('''{
+ "type": "record",
+ "name": "X",
+ "fields": [{
+ "name": "y",
+ "type": {
+ "type": "record",
+ "name": "Y",
+ "fields": [{"name": "Z", "type": "X"}]}
+ }]
+ }''')
+ t = schema.parse(str(s.fields[0].type))
+        # If we've made it this far, the subschema was reasonably stringified; it could be reparsed.
+ self.assertEqual("X", t.fields[0].type.name)
+
+ def test_name_is_none(self):
+ """When a name is None its namespace is None."""
+ self.assertIsNone(schema.Name(None, None, None).fullname)
+ self.assertIsNone(schema.Name(None, None, None).space)
+
+ def test_name_not_empty_string(self):
+ """A name cannot be the empty string."""
+ self.assertRaises(schema.SchemaParseException, schema.Name, "", None, None)
+
+ def test_name_space_specified(self):
+ """Space combines with a name to become the fullname."""
+ # name and namespace specified
+ fullname = schema.Name('a', 'o.a.h', None).fullname
+ self.assertEqual(fullname, 'o.a.h.a')
+
+ def test_fullname_space_specified(self):
+ """When name contains dots, namespace should be ignored."""
+ fullname = schema.Name('a.b.c.d', 'o.a.h', None).fullname
+ self.assertEqual(fullname, 'a.b.c.d')
+
+ def test_name_default_specified(self):
+ """Default space becomes the namespace when the namespace is None."""
+ fullname = schema.Name('a', None, 'b.c.d').fullname
+ self.assertEqual(fullname, 'b.c.d.a')
+
+ def test_fullname_default_specified(self):
+ """When a name contains dots, default space should be ignored."""
+ fullname = schema.Name('a.b.c.d', None, 'o.a.h').fullname
+ self.assertEqual(fullname, 'a.b.c.d')
+
+ def test_fullname_space_default_specified(self):
+ """When a name contains dots, namespace and default space should be ignored."""
+ fullname = schema.Name('a.b.c.d', 'o.a.a', 'o.a.h').fullname
+ self.assertEqual(fullname, 'a.b.c.d')
+
+ def test_name_space_default_specified(self):
+ """When name and space are specified, default space should be ignored."""
+ fullname = schema.Name('a', 'o.a.a', 'o.a.h').fullname
+ self.assertEqual(fullname, 'o.a.a.a')
+
+ def test_equal_names(self):
+ """Equality of names is defined on the fullname and is case-sensitive."""
+ self.assertEqual(schema.Name('a.b.c.d', None, None), schema.Name('d', 'a.b.c', None))
+ self.assertNotEqual(schema.Name('C.d', None, None), schema.Name('c.d', None, None))
+
+ def test_invalid_name(self):
+ """The name portion of a fullname, record field names, and enum symbols must:
+ start with [A-Za-z_] and subsequently contain only [A-Za-z0-9_]"""
+ self.assertRaises(schema.InvalidName, schema.Name, 'an especially spacey cowboy', None, None)
+ self.assertRaises(schema.InvalidName, schema.Name, '99 problems but a name aint one', None, None)
+
+ def test_null_namespace(self):
+ """The empty string may be used as a namespace to indicate the null namespace."""
+ name = schema.Name('name', "", None)
+ self.assertEqual(name.fullname, "name")
+ self.assertIsNone(name.space)
+
+ def test_exception_is_not_swallowed_on_parse_error(self):
+ """A specific exception message should appear on a json parse error."""
+ self.assertRaisesRegexp(schema.SchemaParseException,
+ r'Error parsing JSON: /not/a/real/file',
+ schema.parse,
+ '/not/a/real/file')
+
+ def test_decimal_valid_type(self):
+ fixed_decimal_schema = ValidTestSchema({
+ "type": "fixed",
+ "logicalType": "decimal",
+ "name": "TestDecimal",
+ "precision": 4,
+ "scale": 2,
+ "size": 2})
+
+ bytes_decimal_schema = ValidTestSchema({
+ "type": "bytes",
+ "logicalType": "decimal",
+ "precision": 4})
+
+ fixed_decimal = fixed_decimal_schema.parse()
+ self.assertEqual(4, fixed_decimal.get_prop('precision'))
+ self.assertEqual(2, fixed_decimal.get_prop('scale'))
+ self.assertEqual(2, fixed_decimal.get_prop('size'))
+
+ bytes_decimal = bytes_decimal_schema.parse()
+ self.assertEqual(4, bytes_decimal.get_prop('precision'))
+ self.assertEqual(0, bytes_decimal.get_prop('scale'))
+
+ def test_fixed_decimal_valid_max_precision(self):
+ # An 8 byte number can represent any 18 digit number.
+ fixed_decimal_schema = ValidTestSchema({
+ "type": "fixed",
+ "logicalType": "decimal",
+ "name": "TestDecimal",
+ "precision": 18,
+ "scale": 0,
+ "size": 8})
+
+ fixed_decimal = fixed_decimal_schema.parse()
+ self.assertIsInstance(fixed_decimal, schema.FixedSchema)
+ self.assertIsInstance(fixed_decimal, schema.DecimalLogicalSchema)
+
+ def test_fixed_decimal_invalid_max_precision(self):
+ # An 8 byte number can't represent every 19 digit number, so the logical
+ # type is not applied.
+ fixed_decimal_schema = ValidTestSchema({
+ "type": "fixed",
+ "logicalType": "decimal",
+ "name": "TestDecimal",
+ "precision": 19,
+ "scale": 0,
+ "size": 8})
+
+ fixed_decimal = fixed_decimal_schema.parse()
+ self.assertIsInstance(fixed_decimal, schema.FixedSchema)
+ self.assertNotIsInstance(fixed_decimal, schema.DecimalLogicalSchema)
+
+ def test_parse_invalid_symbol(self):
+ """Disabling enumschema symbol validation should allow invalid symbols to pass."""
+ test_schema_string = json.dumps({
+ "type": "enum", "name": "AVRO2174", "symbols": ["white space"]})
+
+ try:
+ case = schema.parse(test_schema_string, validate_enum_symbols=True)
+ except schema.InvalidName:
+ pass
+ else:
+ self.fail("When enum symbol validation is enabled, "
+ "an invalid symbol should raise InvalidName.")
+
+ try:
+ case = schema.parse(test_schema_string, validate_enum_symbols=False)
+ except schema.InvalidName:
+ self.fail("When enum symbol validation is disabled, "
+ "an invalid symbol should not raise InvalidName.")
class SchemaParseTestCase(unittest.TestCase):
- """Enable generating parse test cases over all the valid and invalid example schema."""
-
- def __init__(self, test_schema):
- """Ignore the normal signature for unittest.TestCase because we are generating
- many test cases from this one class. This is safe as long as the autoloader
- ignores this class. The autoloader will ignore this class as long as it has
- no methods starting with `test_`.
- """
- super(SchemaParseTestCase, self).__init__(
- 'parse_valid' if test_schema.valid else 'parse_invalid')
- self.test_schema = test_schema
- # Never hide repeated warnings when running this test case.
- warnings.simplefilter("always")
-
- def parse_valid(self):
- """Parsing a valid schema should not error, but may contain warnings."""
- with warnings.catch_warnings(record=True) as actual_warnings:
- try:
- self.test_schema.parse()
- except (schema.AvroException, schema.SchemaParseException):
- self.fail("Valid schema failed to parse: {!s}".format(self.test_schema))
- actual_messages = [str(wmsg.message) for wmsg in actual_warnings]
- if self.test_schema.warnings:
- expected_messages = [str(w) for w in self.test_schema.warnings]
- self.assertEqual(actual_messages, expected_messages)
- else:
- self.assertEqual(actual_messages, [])
-
- def parse_invalid(self):
- """Parsing an invalid schema should error."""
- try:
- self.test_schema.parse()
- except (schema.AvroException, schema.SchemaParseException):
- pass
- else:
- self.fail("Invalid schema should not have parsed: {!s}".format(self.test_schema))
+ """Enable generating parse test cases over all the valid and invalid example schema."""
+
+ def __init__(self, test_schema):
+ """Ignore the normal signature for unittest.TestCase because we are generating
+ many test cases from this one class. This is safe as long as the autoloader
+ ignores this class. The autoloader will ignore this class as long as it has
+ no methods starting with `test_`.
+ """
+ super(SchemaParseTestCase, self).__init__(
+ 'parse_valid' if test_schema.valid else 'parse_invalid')
+ self.test_schema = test_schema
+ # Never hide repeated warnings when running this test case.
+ warnings.simplefilter("always")
+
+ def parse_valid(self):
+ """Parsing a valid schema should not error, but may contain warnings."""
+ with warnings.catch_warnings(record=True) as actual_warnings:
+ try:
+ self.test_schema.parse()
+ except (schema.AvroException, schema.SchemaParseException):
+ self.fail("Valid schema failed to parse: {!s}".format(self.test_schema))
+ actual_messages = [str(wmsg.message) for wmsg in actual_warnings]
+ if self.test_schema.warnings:
+ expected_messages = [str(w) for w in self.test_schema.warnings]
+ self.assertEqual(actual_messages, expected_messages)
+ else:
+ self.assertEqual(actual_messages, [])
+
+ def parse_invalid(self):
+ """Parsing an invalid schema should error."""
+ try:
+ self.test_schema.parse()
+ except (schema.AvroException, schema.SchemaParseException):
+ pass
+ else:
+ self.fail("Invalid schema should not have parsed: {!s}".format(self.test_schema))
class RoundTripParseTestCase(unittest.TestCase):
- """Enable generating round-trip parse test cases over all the valid test schema."""
-
- def __init__(self, test_schema):
- """Ignore the normal signature for unittest.TestCase because we are generating
- many test cases from this one class. This is safe as long as the autoloader
- ignores this class. The autoloader will ignore this class as long as it has
- no methods starting with `test_`.
- """
- super(RoundTripParseTestCase, self).__init__('parse_round_trip')
- self.test_schema = test_schema
-
- def parse_round_trip(self):
- """The string of a Schema should be parseable to the same Schema."""
- parsed = self.test_schema.parse()
- round_trip = schema.parse(str(parsed))
- self.assertEqual(parsed, round_trip)
+ """Enable generating round-trip parse test cases over all the valid test schema."""
+
+ def __init__(self, test_schema):
+ """Ignore the normal signature for unittest.TestCase because we are generating
+ many test cases from this one class. This is safe as long as the autoloader
+ ignores this class. The autoloader will ignore this class as long as it has
+ no methods starting with `test_`.
+ """
+ super(RoundTripParseTestCase, self).__init__('parse_round_trip')
+ self.test_schema = test_schema
+
+ def parse_round_trip(self):
+ """The string of a Schema should be parseable to the same Schema."""
+ parsed = self.test_schema.parse()
+ round_trip = schema.parse(str(parsed))
+ self.assertEqual(parsed, round_trip)
+
class DocAttributesTestCase(unittest.TestCase):
- """Enable generating document attribute test cases over all the document test schema."""
-
- def __init__(self, test_schema):
- """Ignore the normal signature for unittest.TestCase because we are generating
- many test cases from this one class. This is safe as long as the autoloader
- ignores this class. The autoloader will ignore this class as long as it has
- no methods starting with `test_`.
- """
- super(DocAttributesTestCase, self).__init__('check_doc_attributes')
- self.test_schema = test_schema
-
- def check_doc_attributes(self):
- """Documentation attributes should be preserved."""
- sch = self.test_schema.parse()
- self.assertIsNotNone(sch.doc, "Failed to preserve 'doc' in schema: {!s}".format(self.test_schema))
- if sch.type == 'record':
- for f in sch.fields:
- self.assertIsNotNone(f.doc, "Failed to preserve 'doc' in fields: {!s}".format(self.test_schema))
+ """Enable generating document attribute test cases over all the document test schema."""
+
+ def __init__(self, test_schema):
+ """Ignore the normal signature for unittest.TestCase because we are generating
+ many test cases from this one class. This is safe as long as the autoloader
+ ignores this class. The autoloader will ignore this class as long as it has
+ no methods starting with `test_`.
+ """
+ super(DocAttributesTestCase, self).__init__('check_doc_attributes')
+ self.test_schema = test_schema
+
+ def check_doc_attributes(self):
+ """Documentation attributes should be preserved."""
+ sch = self.test_schema.parse()
+ self.assertIsNotNone(sch.doc, "Failed to preserve 'doc' in schema: {!s}".format(self.test_schema))
+ if sch.type == 'record':
+ for f in sch.fields:
+ self.assertIsNotNone(f.doc, "Failed to preserve 'doc' in fields: {!s}".format(self.test_schema))
class OtherAttributesTestCase(unittest.TestCase):
- """Enable generating attribute test cases over all the other-prop test schema."""
- _type_map = {
- "cp_array": list,
- "cp_boolean": bool,
- "cp_float": float,
- "cp_int": int,
- "cp_null": type(None),
- "cp_object": dict,
- "cp_string": basestring,
- }
-
- def __init__(self, test_schema):
- """Ignore the normal signature for unittest.TestCase because we are generating
- many test cases from this one class. This is safe as long as the autoloader
- ignores this class. The autoloader will ignore this class as long as it has
- no methods starting with `test_`.
- """
- super(OtherAttributesTestCase, self).__init__('check_attributes')
- self.test_schema = test_schema
-
- def _check_props(self, props):
- for k, v in props.items():
- self.assertIsInstance(v, self._type_map[k])
-
- def check_attributes(self):
- """Other attributes and their types on a schema should be preserved."""
- sch = self.test_schema.parse()
- round_trip = schema.parse(str(sch))
- self.assertEqual(sch.other_props, round_trip.other_props,
- "Properties were not preserved in a round-trip parse.")
- self._check_props(sch.other_props)
- if sch.type == "record":
- field_props = [f.other_props for f in sch.fields if f.other_props]
- self.assertEqual(len(field_props), len(sch.fields))
- for p in field_props:
- self._check_props(p)
+ """Enable generating attribute test cases over all the other-prop test schema."""
+ _type_map = {
+ "cp_array": list,
+ "cp_boolean": bool,
+ "cp_float": float,
+ "cp_int": int,
+ "cp_null": type(None),
+ "cp_object": dict,
+ "cp_string": basestring,
+ }
+
+ def __init__(self, test_schema):
+ """Ignore the normal signature for unittest.TestCase because we are generating
+ many test cases from this one class. This is safe as long as the autoloader
+ ignores this class. The autoloader will ignore this class as long as it has
+ no methods starting with `test_`.
+ """
+ super(OtherAttributesTestCase, self).__init__('check_attributes')
+ self.test_schema = test_schema
+
+ def _check_props(self, props):
+ for k, v in props.items():
+ self.assertIsInstance(v, self._type_map[k])
+
+ def check_attributes(self):
+ """Other attributes and their types on a schema should be preserved."""
+ sch = self.test_schema.parse()
+ round_trip = schema.parse(str(sch))
+ self.assertEqual(sch.other_props, round_trip.other_props,
+ "Properties were not preserved in a round-trip parse.")
+ self._check_props(sch.other_props)
+ if sch.type == "record":
+ field_props = [f.other_props for f in sch.fields if f.other_props]
+ self.assertEqual(len(field_props), len(sch.fields))
+ for p in field_props:
+ self._check_props(p)
def load_tests(loader, default_tests, pattern):
- """Generate test cases across many test schema."""
- suite = unittest.TestSuite()
- suite.addTests(loader.loadTestsFromTestCase(TestMisc))
- suite.addTests(SchemaParseTestCase(ex) for ex in EXAMPLES)
- suite.addTests(RoundTripParseTestCase(ex) for ex in VALID_EXAMPLES)
- suite.addTests(DocAttributesTestCase(ex) for ex in DOC_EXAMPLES)
- suite.addTests(OtherAttributesTestCase(ex) for ex in OTHER_PROP_EXAMPLES)
- return suite
+ """Generate test cases across many test schema."""
+ suite = unittest.TestSuite()
+ suite.addTests(loader.loadTestsFromTestCase(TestMisc))
+ suite.addTests(SchemaParseTestCase(ex) for ex in EXAMPLES)
+ suite.addTests(RoundTripParseTestCase(ex) for ex in VALID_EXAMPLES)
+ suite.addTests(DocAttributesTestCase(ex) for ex in DOC_EXAMPLES)
+ suite.addTests(OtherAttributesTestCase(ex) for ex in OTHER_PROP_EXAMPLES)
+ return suite
+
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
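The SchemaParseTestCase, RoundTripParseTestCase, DocAttributesTestCase and OtherAttributesTestCase classes above rely on two unittest features: a TestCase subclass with no test_* methods is skipped by autodiscovery, and a module-level load_tests(loader, tests, pattern) hook lets the module build its suite explicitly. A stripped-down, self-contained sketch of the same pattern, with illustrative names:

    import unittest

    EXAMPLES = [1, 2, 3]

    class ExampleCase(unittest.TestCase):
        # No method starts with "test_", so autodiscovery ignores this class;
        # instances are only created explicitly in load_tests below.
        def __init__(self, value):
            super(ExampleCase, self).__init__('check_positive')
            self.value = value

        def check_positive(self):
            self.assertGreater(self.value, 0)

    def load_tests(loader, default_tests, pattern):
        suite = unittest.TestSuite()
        suite.addTests(ExampleCase(v) for v in EXAMPLES)
        return suite

    if __name__ == '__main__':
        unittest.main()
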
diff --git a/lang/py/avro/test/test_script.py b/lang/py/avro/test/test_script.py
index 7ebda4e..def98ef 100644
--- a/lang/py/avro/test/test_script.py
+++ b/lang/py/avro/test/test_script.py
@@ -35,9 +35,9 @@ from avro.datafile import DataFileWriter
from avro.io import DatumWriter
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
NUM_RECORDS = 7
@@ -65,9 +65,11 @@ LOONIES = (
(unicode("foghorn"), unicode("leghorn"), unicode("rooster")),
)
+
def looney_records():
for f, l, t in LOONIES:
- yield {"first": f, "last" : l, "type" : t}
+ yield {"first": f, "last": l, "type": t}
+
SCRIPT = join(dirname(dirname(dirname(__file__))), "scripts", "avro")
@@ -77,6 +79,7 @@ _JSON_PRETTY = '''{
"type": "duck"
}'''
+
def gen_avro(filename):
schema = avro.schema.parse(SCHEMA)
fo = open(filename, "wb")
@@ -86,9 +89,11 @@ def gen_avro(filename):
writer.close()
fo.close()
+
def tempfile():
return NamedTemporaryFile(delete=False).name
+
class TestCat(unittest.TestCase):
def setUp(self):
self.avro_file = tempfile()
@@ -157,13 +162,14 @@ class TestCat(unittest.TestCase):
# Empty fields should get all
out = self._run('--fields', '')
assert json.loads(out[0]) == \
- {'first': unicode('daffy'), 'last': unicode('duck'),
- 'type': unicode('duck')}
+ {'first': unicode('daffy'), 'last': unicode('duck'),
+ 'type': unicode('duck')}
# Non existing fields are ignored
out = self._run('--fields', 'first,last,age')
assert json.loads(out[0]) == {'first': unicode('daffy'), 'last': unicode('duck')}
+
class TestWrite(unittest.TestCase):
def setUp(self):
self.json_file = tempfile() + ".json"
@@ -207,7 +213,7 @@ class TestWrite(unittest.TestCase):
def format_check(self, format, filename):
tmp = tempfile()
with open(tmp, "wb") as fo:
- self._run(filename, "-f", format, stdout=fo)
+ self._run(filename, "-f", format, stdout=fo)
records = self.load_avro(tmp)
assert len(records) == NUM_RECORDS
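gen_avro above drives DataFileWriter by hand (open, append, close), and the read side typically uses the matching DataFileReader. A minimal round-trip sketch using the same avro Python API; the schema and file name here are illustrative:

    import avro.schema
    from avro.datafile import DataFileReader, DataFileWriter
    from avro.io import DatumReader, DatumWriter

    LOONY_SCHEMA = avro.schema.parse('''{
        "type": "record", "name": "Loony",
        "fields": [{"name": "first", "type": "string"},
                   {"name": "last", "type": "string"}]
    }''')

    # Write a couple of records to an Avro container file.
    writer = DataFileWriter(open("loonies.avro", "wb"), DatumWriter(), LOONY_SCHEMA)
    writer.append({"first": "daffy", "last": "duck"})
    writer.append({"first": "bugs", "last": "bunny"})
    writer.close()

    # Read them back; the writer schema travels with the file, so none is passed here.
    reader = DataFileReader(open("loonies.avro", "rb"), DatumReader())
    for record in reader:
        print(record)
    reader.close()
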
diff --git a/lang/py/avro/test/test_tether_task.py b/lang/py/avro/test/test_tether_task.py
index 8856281..a9567f3 100644
--- a/lang/py/avro/test/test_tether_task.py
+++ b/lang/py/avro/test/test_tether_task.py
@@ -34,88 +34,90 @@ import avro.tether.util
from avro import schema, tether
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
class TestTetherTask(unittest.TestCase):
- """
-  TODO: We should validate the server response by looking at stdout
- """
- def test_tether_task(self):
"""
- Test that the tether_task is working. We run the mock_tether_parent in a separate
- subprocess
+    TODO: We should validate the server response by looking at stdout
"""
- task = avro.test.word_count_task.WordCountTask()
- proc = None
- pyfile = avro.test.mock_tether_parent.__file__
- server_port = avro.tether.util.find_port()
- input_port = avro.tether.util.find_port()
- try:
- # launch the server in a separate process
- proc = subprocess.Popen([sys.executable, pyfile, "start_server", str(server_port)])
-
- print("Mock server started process pid={}".format(proc.pid))
-
- # Possible race condition? open tries to connect to the subprocess before the subprocess is fully started
- # so we give the subprocess time to start up
- time.sleep(1)
- task.open(input_port, clientPort=server_port)
-
-      # TODO: We should validate that open worked by grabbing the STDOUT of the subprocess
- # and ensuring that it outputted the correct message.
-
- #***************************************************************
- # Test the mapper
- task.configure(
- avro.tether.tether_task.TaskType.MAP,
- str(task.inschema),
- str(task.midschema)
- )
-
- # Serialize some data so we can send it to the input function
- datum = unicode("This is a line of text")
- writer = io.BytesIO()
- encoder = avro.io.BinaryEncoder(writer)
- datum_writer = avro.io.DatumWriter(task.inschema)
- datum_writer.write(datum, encoder)
-
- writer.seek(0)
- data = writer.read()
-
- # Call input to simulate calling map
- task.input(data, 1)
-
- # Test the reducer
- task.configure(
- avro.tether.tether_task.TaskType.REDUCE,
- str(task.midschema),
- str(task.outschema)
- )
-
- # Serialize some data so we can send it to the input function
- datum = {"key": unicode("word"), "value": 2}
- writer = io.BytesIO()
- encoder = avro.io.BinaryEncoder(writer)
- datum_writer = avro.io.DatumWriter(task.midschema)
- datum_writer.write(datum, encoder)
-
- writer.seek(0)
- data = writer.read()
-
- # Call input to simulate calling reduce
- task.input(data, 1)
-
- task.complete()
-
- # try a status
- task.status(unicode("Status message"))
- finally:
- # close the process
- if not(proc is None):
- proc.kill()
+
+ def test_tether_task(self):
+ """
+ Test that the tether_task is working. We run the mock_tether_parent in a separate
+ subprocess
+ """
+ task = avro.test.word_count_task.WordCountTask()
+ proc = None
+ pyfile = avro.test.mock_tether_parent.__file__
+ server_port = avro.tether.util.find_port()
+ input_port = avro.tether.util.find_port()
+ try:
+ # launch the server in a separate process
+ proc = subprocess.Popen([sys.executable, pyfile, "start_server", str(server_port)])
+
+ print("Mock server started process pid={}".format(proc.pid))
+
+ # Possible race condition? open tries to connect to the subprocess before the subprocess is fully started
+ # so we give the subprocess time to start up
+ time.sleep(1)
+ task.open(input_port, clientPort=server_port)
+
+            # TODO: We should validate that open worked by grabbing the STDOUT of the subprocess
+ # and ensuring that it outputted the correct message.
+
+ # ***************************************************************
+ # Test the mapper
+ task.configure(
+ avro.tether.tether_task.TaskType.MAP,
+ str(task.inschema),
+ str(task.midschema)
+ )
+
+ # Serialize some data so we can send it to the input function
+ datum = unicode("This is a line of text")
+ writer = io.BytesIO()
+ encoder = avro.io.BinaryEncoder(writer)
+ datum_writer = avro.io.DatumWriter(task.inschema)
+ datum_writer.write(datum, encoder)
+
+ writer.seek(0)
+ data = writer.read()
+
+ # Call input to simulate calling map
+ task.input(data, 1)
+
+ # Test the reducer
+ task.configure(
+ avro.tether.tether_task.TaskType.REDUCE,
+ str(task.midschema),
+ str(task.outschema)
+ )
+
+ # Serialize some data so we can send it to the input function
+ datum = {"key": unicode("word"), "value": 2}
+ writer = io.BytesIO()
+ encoder = avro.io.BinaryEncoder(writer)
+ datum_writer = avro.io.DatumWriter(task.midschema)
+ datum_writer.write(datum, encoder)
+
+ writer.seek(0)
+ data = writer.read()
+
+ # Call input to simulate calling reduce
+ task.input(data, 1)
+
+ task.complete()
+
+ # try a status
+ task.status(unicode("Status message"))
+ finally:
+ # close the process
+ if not(proc is None):
+ proc.kill()
+
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
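The test above serializes a datum with DatumWriter and BinaryEncoder before handing the raw bytes to task.input(); decoding is the mirror image with BinaryDecoder and DatumReader. A self-contained sketch of that round trip (assuming Python 3, where str is the schema's string type):

    import io

    import avro.io
    import avro.schema

    LINE_SCHEMA = avro.schema.parse('"string"')

    # Encode a datum to raw Avro bytes, as the tether tests do before calling task.input().
    buf = io.BytesIO()
    avro.io.DatumWriter(LINE_SCHEMA).write("This is a line of text", avro.io.BinaryEncoder(buf))
    data = buf.getvalue()

    # Decode: wrap the bytes in a BinaryDecoder and read with the same schema.
    decoder = avro.io.BinaryDecoder(io.BytesIO(data))
    assert avro.io.DatumReader(LINE_SCHEMA).read(decoder) == "This is a line of text"
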
diff --git a/lang/py/avro/test/test_tether_task_runner.py b/lang/py/avro/test/test_tether_task_runner.py
index 93827a0..10582d3 100644
--- a/lang/py/avro/test/test_tether_task_runner.py
+++ b/lang/py/avro/test/test_tether_task_runner.py
@@ -35,164 +35,158 @@ import avro.tether.tether_task_runner
import avro.tether.util
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
class TestTetherTaskRunner(unittest.TestCase):
- """unit test for a tethered task runner."""
+ """unit test for a tethered task runner."""
+
+ def test1(self):
+ # set the logging level to debug so that debug messages are printed
+ logging.basicConfig(level=logging.DEBUG)
+
+ proc = None
+ try:
+ # launch the server in a separate process
+ env = dict()
+ env["PYTHONPATH"] = ':'.join(sys.path)
+ parent_port = avro.tether.util.find_port()
+
+ pyfile = avro.test.mock_tether_parent.__file__
+ proc = subprocess.Popen([sys.executable, pyfile, "start_server", "{0}".format(parent_port)])
+ input_port = avro.tether.util.find_port()
+
+ print("Mock server started process pid={0}".format(proc.pid))
+ # Possible race condition? open tries to connect to the subprocess before the subprocess is fully started
+ # so we give the subprocess time to start up
+ time.sleep(1)
+
+ runner = avro.tether.tether_task_runner.TaskRunner(avro.test.word_count_task.WordCountTask())
+
+ runner.start(outputport=parent_port, join=False)
+
+ # Test sending various messages to the server and ensuring they are processed correctly
+ requestor = avro.tether.tether_task.HTTPRequestor(
+ "localhost", runner.server.server_address[1], avro.tether.tether_task.inputProtocol)
+
+            # TODO: We should validate that open worked by grabbing the STDOUT of the subprocess
+ # and ensuring that it outputted the correct message.
+
+ # Test the mapper
+ requestor.request("configure", {
+ "taskType": avro.tether.tether_task.TaskType.MAP,
+ "inSchema": unicode(str(runner.task.inschema)),
+ "outSchema": unicode(str(runner.task.midschema))
+ })
+
+ # Serialize some data so we can send it to the input function
+ datum = unicode("This is a line of text")
+ writer = io.BytesIO()
+ encoder = avro.io.BinaryEncoder(writer)
+ datum_writer = avro.io.DatumWriter(runner.task.inschema)
+ datum_writer.write(datum, encoder)
+
+ writer.seek(0)
+ data = writer.read()
+
+ # Call input to simulate calling map
+ requestor.request("input", {"data": data, "count": 1})
+
+ # Test the reducer
+ requestor.request("configure", {
+ "taskType": avro.tether.tether_task.TaskType.REDUCE,
+ "inSchema": unicode(str(runner.task.midschema)),
+ "outSchema": unicode(str(runner.task.outschema))}
+ )
+
+ # Serialize some data so we can send it to the input function
+ datum = {"key": unicode("word"), "value": 2}
+ writer = io.BytesIO()
+ encoder = avro.io.BinaryEncoder(writer)
+ datum_writer = avro.io.DatumWriter(runner.task.midschema)
+ datum_writer.write(datum, encoder)
+
+ writer.seek(0)
+ data = writer.read()
+
+ # Call input to simulate calling reduce
+ requestor.request("input", {"data": data, "count": 1})
+
+ requestor.request("complete", {})
+
+ runner.task.ready_for_shutdown.wait()
+ runner.server.shutdown()
+ # time.sleep(2)
+ # runner.server.shutdown()
+
+ sthread = runner.sthread
+
+ # Possible race condition?
+ time.sleep(1)
+
+ # make sure the other thread terminated
+ self.assertFalse(sthread.isAlive())
+
+ # shutdown the logging
+ logging.shutdown()
+
+ except Exception as e:
+ raise
+ finally:
+ # close the process
+ if not(proc is None):
+ proc.kill()
+
+ def test2(self):
+ """
+ In this test we want to make sure that when we run "tether_task_runner.py"
+ as our main script everything works as expected. We do this by using subprocess to run it
+ in a separate thread.
+ """
+ proc = None
+
+ runnerproc = None
+ try:
+ # launch the server in a separate process
+ env = dict()
+ env["PYTHONPATH"] = ':'.join(sys.path)
+ parent_port = avro.tether.util.find_port()
+
+ pyfile = avro.test.mock_tether_parent.__file__
+ proc = subprocess.Popen([sys.executable, pyfile, "start_server", "{0}".format(parent_port)])
+
+            # Possible race condition? When we start tether_task_runner, its call to open()
+            # tries to connect to the subprocess before the subprocess is fully started,
+            # so we give the subprocess time to start up.
+ time.sleep(1)
+
+ # start the tether_task_runner in a separate process
+ env = {"AVRO_TETHER_OUTPUT_PORT": "{0}".format(parent_port)}
+ env["PYTHONPATH"] = ':'.join(sys.path)
+
+ runnerproc = subprocess.Popen([sys.executable, avro.tether.tether_task_runner.__file__,
+ "avro.test.word_count_task.WordCountTask"], env=env)
+
+ # possible race condition wait for the process to start
+ time.sleep(1)
+
+ print("Mock server started process pid={0}".format(proc.pid))
+ # Possible race condition? open tries to connect to the subprocess before the subprocess is fully started
+ # so we give the subprocess time to start up
+ time.sleep(1)
+
+ except Exception as e:
+ raise
+ finally:
+ # close the process
+ if not(runnerproc is None):
+ runnerproc.kill()
+
+ if not(proc is None):
+ proc.kill()
+
- def test1(self):
- # set the logging level to debug so that debug messages are printed
- logging.basicConfig(level=logging.DEBUG)
-
- proc=None
- try:
- # launch the server in a separate process
- env=dict()
- env["PYTHONPATH"]=':'.join(sys.path)
- parent_port = avro.tether.util.find_port()
-
- pyfile=avro.test.mock_tether_parent.__file__
- proc=subprocess.Popen([sys.executable, pyfile,"start_server","{0}".format(parent_port)])
- input_port = avro.tether.util.find_port()
-
- print("Mock server started process pid={0}".format(proc.pid))
- # Possible race condition? open tries to connect to the subprocess before the subprocess is fully started
- # so we give the subprocess time to start up
- time.sleep(1)
-
- runner = avro.tether.tether_task_runner.TaskRunner(avro.test.word_count_task.WordCountTask())
-
- runner.start(outputport=parent_port,join=False)
-
- # Test sending various messages to the server and ensuring they are processed correctly
- requestor = avro.tether.tether_task.HTTPRequestor(
- "localhost", runner.server.server_address[1], avro.tether.tether_task.inputProtocol)
-
-      # TODO: We should validate that open worked by grabbing the STDOUT of the subprocess
- # and ensuring that it outputted the correct message.
-
- # Test the mapper
- requestor.request("configure", {
- "taskType": avro.tether.tether_task.TaskType.MAP,
- "inSchema": unicode(str(runner.task.inschema)),
- "outSchema": unicode(str(runner.task.midschema))
- })
-
- # Serialize some data so we can send it to the input function
- datum = unicode("This is a line of text")
- writer = io.BytesIO()
- encoder = avro.io.BinaryEncoder(writer)
- datum_writer = avro.io.DatumWriter(runner.task.inschema)
- datum_writer.write(datum, encoder)
-
- writer.seek(0)
- data=writer.read()
-
-
- # Call input to simulate calling map
- requestor.request("input",{"data":data,"count":1})
-
- # Test the reducer
- requestor.request("configure", {
- "taskType": avro.tether.tether_task.TaskType.REDUCE,
- "inSchema": unicode(str(runner.task.midschema)),
- "outSchema": unicode(str(runner.task.outschema))}
- )
-
- #Serialize some data so we can send it to the input function
- datum = {"key": unicode("word"), "value": 2}
- writer = io.BytesIO()
- encoder = avro.io.BinaryEncoder(writer)
- datum_writer = avro.io.DatumWriter(runner.task.midschema)
- datum_writer.write(datum, encoder)
-
- writer.seek(0)
- data=writer.read()
-
-
- #Call input to simulate calling reduce
- requestor.request("input",{"data":data,"count":1})
-
- requestor.request("complete",{})
-
-
- runner.task.ready_for_shutdown.wait()
- runner.server.shutdown()
- #time.sleep(2)
- #runner.server.shutdown()
-
- sthread=runner.sthread
-
- #Possible race condition?
- time.sleep(1)
-
- #make sure the other thread terminated
- self.assertFalse(sthread.isAlive())
-
- #shutdown the logging
- logging.shutdown()
-
- except Exception as e:
- raise
- finally:
- #close the process
- if not(proc is None):
- proc.kill()
-
-
- def test2(self):
- """
- In this test we want to make sure that when we run "tether_task_runner.py"
- as our main script everything works as expected. We do this by using subprocess to run it
- in a separate thread.
- """
- proc=None
-
- runnerproc=None
- try:
- #launch the server in a separate process
- env=dict()
- env["PYTHONPATH"]=':'.join(sys.path)
- parent_port = avro.tether.util.find_port()
-
- pyfile=avro.test.mock_tether_parent.__file__
- proc=subprocess.Popen([sys.executable, pyfile,"start_server","{0}".format(parent_port)])
-
- #Possible race condition? when we start tether_task_runner it will call
- # open tries to connect to the subprocess before the subprocess is fully started
- #so we give the subprocess time to start up
- time.sleep(1)
-
-
- #start the tether_task_runner in a separate process
- env={"AVRO_TETHER_OUTPUT_PORT":"{0}".format(parent_port)}
- env["PYTHONPATH"]=':'.join(sys.path)
-
- runnerproc = subprocess.Popen([sys.executable, avro.tether.tether_task_runner.__file__, "avro.test.word_count_task.WordCountTask"], env=env)
-
- #possible race condition wait for the process to start
- time.sleep(1)
-
-
-
- print("Mock server started process pid={0}".format(proc.pid))
- #Possible race condition? open tries to connect to the subprocess before the subprocess is fully started
- #so we give the subprocess time to start up
- time.sleep(1)
-
-
- except Exception as e:
- raise
- finally:
- #close the process
- if not(runnerproc is None):
- runnerproc.kill()
-
- if not(proc is None):
- proc.kill()
-
-if __name__==("__main__"):
- unittest.main()
+if __name__ == ("__main__"):
+ unittest.main()
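Both tests above paper over the startup race with fixed time.sleep(1) calls, as the comments note. One common alternative, not used in this suite and shown here only as a sketch (assuming Python 3), is to poll until the child process is actually listening on its port:

    import socket
    import time

    def wait_for_port(port, host="localhost", timeout=10.0):
        """Block until something accepts connections on (host, port), or time out."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                # A successful connection means the server side is up; close it immediately.
                with socket.create_connection((host, port), timeout=1.0):
                    return
            except OSError:
                time.sleep(0.1)
        raise RuntimeError("nothing listening on port {0} after {1}s".format(port, timeout))
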
diff --git a/lang/py/avro/test/test_tether_word_count.py b/lang/py/avro/test/test_tether_word_count.py
index 50cd978..c2ec47c 100644
--- a/lang/py/avro/test/test_tether_word_count.py
+++ b/lang/py/avro/test/test_tether_word_count.py
@@ -36,20 +36,22 @@ import avro.schema
import avro.tether.tether_task_runner
try:
- unicode
+ unicode
except NameError:
- unicode = str
+ unicode = str
_AVRO_DIR = os.path.abspath(os.path.dirname(avro.__file__))
+
def _version():
- with open(os.path.join(_AVRO_DIR, 'VERSION.txt')) as v:
- # Convert it back to the java version
- return v.read().strip().replace('+', '-')
+ with open(os.path.join(_AVRO_DIR, 'VERSION.txt')) as v:
+ # Convert it back to the java version
+ return v.read().strip().replace('+', '-')
+
_AVRO_VERSION = _version()
_JAR_PATH = os.path.join(os.path.dirname(os.path.dirname(_AVRO_DIR)),
- "java", "tools", "target", "avro-tools-{}.jar".format(_AVRO_VERSION))
+ "java", "tools", "target", "avro-tools-{}.jar".format(_AVRO_VERSION))
_LINES = (unicode("the quick brown fox jumps over the lazy dog"),
unicode("the cow jumps over the moon"),
@@ -70,96 +72,100 @@ _PYTHON_PATH = os.pathsep.join([os.path.dirname(os.path.dirname(avro.__file__)),
def _has_java():
- """Detect if this system has a usable java installed.
+ """Detect if this system has a usable java installed.
- On most systems, this is just checking if `java` is in the PATH.
+ On most systems, this is just checking if `java` is in the PATH.
- But macos always has a /usr/bin/java, which does not mean java is installed. If you invoke java on macos and java is not installed, macos will spawn a popup telling you how to install java. This code does additional work around that to be completely automatic.
- """
- if platform.system() == "Darwin":
- try:
- output = subprocess.check_output("/usr/libexec/java_home", stderr=subprocess.STDOUT)
- except subprocess.CalledProcessError as e:
- output = e.output
- return (b"No Java runtime present" not in output)
- return bool(distutils.spawn.find_executable("java"))
+    But macOS always has a /usr/bin/java, which does not mean java is installed.
+    If you invoke java on macOS and java is not installed, macOS will spawn a popup
+ telling you how to install java. This code does additional work around that
+ to be completely automatic.
+ """
+ if platform.system() == "Darwin":
+ try:
+ output = subprocess.check_output("/usr/libexec/java_home", stderr=subprocess.STDOUT)
+ except subprocess.CalledProcessError as e:
+ output = e.output
+ return (b"No Java runtime present" not in output)
+ return bool(distutils.spawn.find_executable("java"))
@unittest.skipUnless(_has_java(), "No Java runtime present")
@unittest.skipUnless(os.path.exists(_JAR_PATH), "{} not found".format(_JAR_PATH))
... 2679 lines suppressed ...