You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/07/01 19:04:39 UTC
[beam] branch master updated: Python: Use RowTypeConstraint for normalizing all schema-inferrable user types (#22066)
This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 07658c6cce9 Python: Use RowTypeConstraint for normalizing all schema-inferrable user types (#22066)
07658c6cce9 is described below
commit 07658c6cce9dd82c59b5e521e08150b2df443391
Author: Brian Hulette <bh...@google.com>
AuthorDate: Fri Jul 1 12:04:31 2022 -0700
Python: Use RowTypeConstraint for normalizing all schema-inferrable user types (#22066)
* Use RowTypeConstraint for normalizing schema-inferrable user types
* Fix user_type generation for nested rows
* Fix docstring
* Backout sdk_options changes
* fix import
* Backout sql_test changes
* cleanup
* Update sdks/python/apache_beam/typehints/schemas.py
Co-authored-by: Andy Ye <an...@gmail.com>
* Address review comments
* yapf
* Clarify docstring
Co-authored-by: Andy Ye <an...@gmail.com>
---
sdks/python/apache_beam/typehints/row_type.py | 85 +++++++++++-
sdks/python/apache_beam/typehints/schemas.py | 89 +++++++------
sdks/python/apache_beam/typehints/schemas_test.py | 146 ++++++++++++---------
.../apache_beam/typehints/trivial_inference.py | 9 +-
4 files changed, 218 insertions(+), 111 deletions(-)
diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py
index f7098fb9067..37e52c62a2e 100644
--- a/sdks/python/apache_beam/typehints/row_type.py
+++ b/sdks/python/apache_beam/typehints/row_type.py
@@ -17,19 +17,98 @@
# pytype: skip-file
+from __future__ import annotations
+
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+
from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
+
+# Name of the attribute added to user types (existing and generated) to store
+# the corresponding schema ID
+_BEAM_SCHEMA_ID = "_beam_schema_id"
class RowTypeConstraint(typehints.TypeConstraint):
- def __init__(self, fields):
- self._fields = tuple(fields)
+ def __init__(self, fields: List[Tuple[str, type]], user_type=None):
+ """For internal use only, no backwards compatibility guarantees. See
+ https://beam.apache.org/documentation/programming-guide/#schemas-for-pl-types
+ for guidance on creating PCollections with inferred schemas.
+
+ Note RowTypeConstraint does not currently store arbitrary functions for
+ converting to/from the user type. Instead, we only support ``NamedTuple``
+ user types and make the following assumptions:
+
+ - The user type can be constructed with field values as arguments in order
+ (i.e. ``constructor(*field_values)``).
+ - Field values can be accessed from instances of the user type by attribute
+ (i.e. with ``getattr(obj, field_name)``).
+
+ In the future we will add support for dataclasses
+ ([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy
+ these assumptions.
+
+ The RowTypeConstraint constructor should not be called directly (even
+ internally to Beam). Prefer static methods ``from_user_type`` or
+ ``from_fields``.
+
+ Parameters:
+ fields: a list of (name, type) tuples, representing the schema inferred
+ from user_type.
+ user_type: constructor for a user type (e.g. NamedTuple class) that is
+ used to represent this schema in user code.
+ """
+ # Recursively wrap row types in a RowTypeConstraint
+ self._fields = tuple((name, RowTypeConstraint.from_user_type(typ) or typ)
+ for name,
+ typ in fields)
+
+ self._user_type = user_type
+
+ # Note schema ID can be None if the schema is not registered yet.
+ # Currently registration happens when converting to schema protos, in
+ # apache_beam.typehints.schemas
+ if self._user_type is not None:
+ self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
+ else:
+ self._schema_id = None
+
+ @staticmethod
+ def from_user_type(user_type: type) -> Optional[RowTypeConstraint]:
+ if match_is_named_tuple(user_type):
+ fields = [(name, user_type.__annotations__[name])
+ for name in user_type._fields]
+
+ return RowTypeConstraint(fields=fields, user_type=user_type)
+
+ return None
+
+ @staticmethod
+ def from_fields(fields: Sequence[Tuple[str, type]]) -> RowTypeConstraint:
+ return RowTypeConstraint(fields=fields, user_type=None)
+
+ @property
+ def user_type(self):
+ return self._user_type
+
+ def set_schema_id(self, schema_id):
+ self._schema_id = schema_id
+ if self._user_type is not None:
+ setattr(self._user_type, _BEAM_SCHEMA_ID, self._schema_id)
+
+ @property
+ def schema_id(self):
+ return self._schema_id
def _consistent_with_check_(self, sub):
return self == sub
def type_check(self, instance):
from apache_beam import Row
- return isinstance(instance, Row)
+ return isinstance(instance, (Row, self._user_type))
def _inner_types(self):
"""Iterates over the inner types of the composite type."""
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index 74ace0e319a..6f56d2a6fa4 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -145,10 +145,6 @@ PRIMITIVE_TO_ATOMIC_TYPE.update({
float: schema_pb2.DOUBLE,
})
-# Name of the attribute added to user types (existing and generated) to store
-# the corresponding schema ID
-_BEAM_SCHEMA_ID = "_beam_schema_id"
-
def named_fields_to_schema(names_and_types):
# type: (Union[Dict[str, type], Sequence[Tuple[str, type]]]) -> schema_pb2.Schema # noqa: F821
@@ -191,46 +187,43 @@ class SchemaTranslation(object):
if isinstance(type_, schema_pb2.Schema):
return schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=type_))
- elif match_is_named_tuple(type_):
- if hasattr(type_, _BEAM_SCHEMA_ID):
- schema_id = getattr(type_, _BEAM_SCHEMA_ID)
- schema = self.schema_registry.get_schema_by_id(
- getattr(type_, _BEAM_SCHEMA_ID))
- else:
- schema_id = self.schema_registry.generate_new_id()
+ if isinstance(type_, row_type.RowTypeConstraint):
+ if type_.schema_id is None:
+ schema_id = SCHEMA_REGISTRY.generate_new_id()
+ type_.set_schema_id(schema_id)
schema = None
- setattr(type_, _BEAM_SCHEMA_ID, schema_id)
+ else:
+ schema_id = type_.schema_id
+ schema = self.schema_registry.get_schema_by_id(schema_id)
if schema is None:
- fields = [
- schema_pb2.Field(
- name=name,
- type=typing_to_runner_api(type_.__annotations__[name]))
- for name in type_._fields
- ]
- schema = schema_pb2.Schema(fields=fields, id=schema_id)
- self.schema_registry.add(type_, schema)
-
+ # Either user_type was not annotated with a schema id, or there was
+ # no schema in the registry with the id. The latter should only happen
+ # in tests.
+ # Either way, we need to generate a new schema proto.
+ schema = schema_pb2.Schema(
+ fields=[
+ schema_pb2.Field(
+ name=name, type=self.typing_to_runner_api(field_type))
+ for (name, field_type) in type_._fields
+ ],
+ id=schema_id)
+ self.schema_registry.add(type_.user_type, schema)
return schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=schema))
-
- elif isinstance(type_, row_type.RowTypeConstraint):
- return schema_pb2.FieldType(
- row_type=schema_pb2.RowType(
- schema=schema_pb2.Schema(
- fields=[
- schema_pb2.Field(
- name=name, type=typing_to_runner_api(field_type))
- for (name, field_type) in type_._fields
- ],
- id=self.schema_registry.generate_new_id())))
+ else:
+ # See if this is coercible to a RowTypeConstraint (e.g. a NamedTuple or
+ # dataclass)
+ row_type_constraint = row_type.RowTypeConstraint.from_user_type(type_)
+ if row_type_constraint is not None:
+ return self.typing_to_runner_api(row_type_constraint)
# All concrete types (other than NamedTuple sub-classes) should map to
# a supported primitive type.
- elif type_ in PRIMITIVE_TO_ATOMIC_TYPE:
+ if type_ in PRIMITIVE_TO_ATOMIC_TYPE:
return schema_pb2.FieldType(atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[type_])
elif _match_is_exactly_mapping(type_):
- key_type, value_type = map(typing_to_runner_api, _get_args(type_))
+ key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
return schema_pb2.FieldType(
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))
@@ -238,17 +231,17 @@ class SchemaTranslation(object):
# It's possible that a user passes us Optional[Optional[T]], but in python
# typing this is indistinguishable from Optional[T] - both resolve to
# Union[T, None] - so there's no need to check for that case here.
- result = typing_to_runner_api(extract_optional_type(type_))
+ result = self.typing_to_runner_api(extract_optional_type(type_))
result.nullable = True
return result
elif _safe_issubclass(type_, Sequence):
- element_type = typing_to_runner_api(_get_args(type_)[0])
+ element_type = self.typing_to_runner_api(_get_args(type_)[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))
elif _safe_issubclass(type_, Mapping):
- key_type, value_type = map(typing_to_runner_api, _get_args(type_))
+ key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
return schema_pb2.FieldType(
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))
@@ -264,7 +257,7 @@ class SchemaTranslation(object):
return schema_pb2.FieldType(
logical_type=schema_pb2.LogicalType(
urn=logical_type.urn(),
- representation=typing_to_runner_api(
+ representation=self.typing_to_runner_api(
logical_type.representation_type())))
def typing_from_runner_api(
@@ -297,8 +290,12 @@ class SchemaTranslation(object):
self.typing_from_runner_api(fieldtype_proto.map_type.value_type)]
elif type_info == "row_type":
schema = fieldtype_proto.row_type.schema
+ # First look for user type in the registry
user_type = self.schema_registry.get_typing_by_id(schema.id)
+
if user_type is None:
+ # If not in the registry (the coder likely came from another SDK),
+ # generate a NamedTuple type to use.
from apache_beam import coders
type_name = 'BeamSchema_{}'.format(schema.id.replace('-', '_'))
@@ -307,6 +304,8 @@ class SchemaTranslation(object):
for field in schema.fields:
try:
field_py_type = self.typing_from_runner_api(field.type)
+ if isinstance(field_py_type, row_type.RowTypeConstraint):
+ field_py_type = field_py_type.user_type
except ValueError as e:
raise ValueError(
"Failed to decode schema due to an issue with Field proto:\n\n"
@@ -316,8 +315,6 @@ class SchemaTranslation(object):
user_type = NamedTuple(type_name, subfields)
- setattr(user_type, _BEAM_SCHEMA_ID, schema.id)
-
# Define a reduce function, otherwise these types can't be pickled
# (See BEAM-9574)
def __reduce__(self):
@@ -329,7 +326,11 @@ class SchemaTranslation(object):
self.schema_registry.add(user_type, schema)
coders.registry.register_coder(user_type, coders.RowCoder)
- return user_type
+ result = row_type.RowTypeConstraint.from_user_type(user_type)
+ result.set_schema_id(schema.id)
+ return result
+ else:
+ return row_type.RowTypeConstraint.from_user_type(user_type)
elif type_info == "logical_type":
if fieldtype_proto.logical_type.urn == PYTHON_ANY_URN:
@@ -349,9 +350,11 @@ def _hydrate_namedtuple_instance(encoded_schema, values):
def named_tuple_from_schema(
schema, schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY) -> type:
- return typing_from_runner_api(
+ row_type_constraint = typing_from_runner_api(
schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=schema)),
- schema_registry)
+ schema_registry=schema_registry)
+ assert isinstance(row_type_constraint, row_type.RowTypeConstraint)
+ return row_type_constraint.user_type
def named_tuple_to_schema(
diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py
index 404d9c5583c..24ec9d11fa5 100644
--- a/sdks/python/apache_beam/typehints/schemas_test.py
+++ b/sdks/python/apache_beam/typehints/schemas_test.py
@@ -30,9 +30,11 @@ from typing import Optional
from typing import Sequence
import numpy as np
+from parameterized import parameterized
from apache_beam.portability import common_urns
from apache_beam.portability.api import schema_pb2
+from apache_beam.typehints import row_type
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schemas import SchemaTypeRegistry
from apache_beam.typehints.schemas import named_tuple_from_schema
@@ -41,6 +43,59 @@ from apache_beam.typehints.schemas import typing_from_runner_api
from apache_beam.typehints.schemas import typing_to_runner_api
from apache_beam.utils.timestamp import Timestamp
+all_nonoptional_primitives = [
+ np.int8,
+ np.int16,
+ np.int32,
+ np.int64,
+ np.float32,
+ np.float64,
+ bool,
+ bytes,
+ str,
+]
+
+all_optional_primitives = [Optional[typ] for typ in all_nonoptional_primitives]
+
+all_primitives = all_nonoptional_primitives + all_optional_primitives
+
+basic_array_types = [Sequence[typ] for typ in all_primitives]
+
+basic_map_types = [
+ Mapping[key_type, value_type] for key_type,
+ value_type in itertools.product(all_primitives, all_primitives)
+]
+
+
+class AllPrimitives(NamedTuple):
+ field_int8: np.int8
+ field_int16: np.int16
+ field_int32: np.int32
+ field_int64: np.int64
+ field_float32: np.float32
+ field_float64: np.float64
+ field_bool: bool
+ field_bytes: bytes
+ field_str: str
+ field_optional_int8: Optional[np.int8]
+ field_optional_int16: Optional[np.int16]
+ field_optional_int32: Optional[np.int32]
+ field_optional_int64: Optional[np.int64]
+ field_optional_float32: Optional[np.float32]
+ field_optional_float64: Optional[np.float64]
+ field_optional_bool: Optional[bool]
+ field_optional_bytes: Optional[bytes]
+ field_optional_str: Optional[str]
+
+
+class ComplexSchema(NamedTuple):
+ id: np.int64
+ name: str
+ optional_map: Optional[Mapping[str, Optional[np.float64]]]
+ optional_array: Optional[Sequence[np.float32]]
+ array_optional: Sequence[Optional[bool]]
+ timestamp: Timestamp
+
class SchemaTest(unittest.TestCase):
""" Tests for Runner API Schema proto to/from typing conversions
@@ -50,68 +105,27 @@ class SchemaTest(unittest.TestCase):
are cached by ID, so performing just one of them wouldn't necessarily exercise
all code paths.
"""
- def test_typing_survives_proto_roundtrip(self):
- all_nonoptional_primitives = [
- np.int8,
- np.int16,
- np.int32,
- np.int64,
- np.float32,
- np.float64,
- bool,
- bytes,
- str,
- ]
-
- all_optional_primitives = [
- Optional[typ] for typ in all_nonoptional_primitives
- ]
-
- all_primitives = all_nonoptional_primitives + all_optional_primitives
-
- basic_array_types = [Sequence[typ] for typ in all_primitives]
-
- basic_map_types = [
- Mapping[key_type, value_type] for key_type,
- value_type in itertools.product(all_primitives, all_primitives)
- ]
-
- selected_schemas = [
- NamedTuple(
- 'AllPrimitives',
- [('field%d' % i, typ) for i, typ in enumerate(all_primitives)]),
- NamedTuple(
- 'ComplexSchema',
- [
- ('id', np.int64),
- ('name', str),
- ('optional_map', Optional[Mapping[str, Optional[np.float64]]]),
- ('optional_array', Optional[Sequence[np.float32]]),
- ('array_optional', Sequence[Optional[bool]]),
- ('timestamp', Timestamp),
- ])
- ]
+ @parameterized.expand([(user_type,) for user_type in
+ all_primitives + \
+ basic_array_types + \
+ basic_map_types]
+ )
+ def test_typing_survives_proto_roundtrip(self, user_type):
+ self.assertEqual(
+ user_type,
+ typing_from_runner_api(
+ typing_to_runner_api(
+ user_type, schema_registry=SchemaTypeRegistry()),
+ schema_registry=SchemaTypeRegistry()))
- test_cases = all_primitives + \
- basic_array_types + \
- basic_map_types
+ @parameterized.expand([(AllPrimitives, ), (ComplexSchema, )])
+ def test_namedtuple_roundtrip(self, user_type):
+ roundtripped = typing_from_runner_api(
+ typing_to_runner_api(user_type, schema_registry=SchemaTypeRegistry()),
+ schema_registry=SchemaTypeRegistry())
- for test_case in test_cases:
- self.assertEqual(
- test_case,
- typing_from_runner_api(
- typing_to_runner_api(
- test_case, schema_registry=SchemaTypeRegistry()),
- schema_registry=SchemaTypeRegistry()))
-
- # Break out NamedTuple types since they require special verification
- for test_case in selected_schemas:
- self.assert_namedtuple_equivalent(
- test_case,
- typing_from_runner_api(
- typing_to_runner_api(
- test_case, schema_registry=SchemaTypeRegistry()),
- schema_registry=SchemaTypeRegistry()))
+ self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
+ self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)
def assert_namedtuple_equivalent(self, actual, expected):
# Two types are only considered equal if they are literally the same
@@ -124,9 +138,16 @@ class SchemaTest(unittest.TestCase):
self.assertTrue(match_is_named_tuple(expected))
self.assertTrue(match_is_named_tuple(actual))
+ # TODO(https://github.com/apache/beam/issues/22082): This will break for
+ # nested complex types.
self.assertEqual(actual.__annotations__, expected.__annotations__)
- self.assertEqual(dir(actual), dir(expected))
+ # TODO(https://github.com/apache/beam/issues/22082): Serialize user_type and
+ # re-hydrate with sdk_options to make these checks pass.
+ #self.assertEqual(dir(actual), dir(expected))
+ #
+ #for attr in dir(expected):
+ # self.assertEqual(getattr(actual, attr), getattr(expected, attr))
def test_proto_survives_typing_roundtrip(self):
all_nonoptional_primitives = [
@@ -253,7 +274,8 @@ class SchemaTest(unittest.TestCase):
schema_pb2.FieldType(
logical_type=schema_pb2.LogicalType(
urn=common_urns.python_callable.urn,
- representation=typing_to_runner_api(str)))),
+ representation=typing_to_runner_api(str))),
+ schema_registry=SchemaTypeRegistry()),
PythonCallableWithSource)
def test_trivial_example(self):
diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py
index 14df8e79439..4de89ffd83b 100644
--- a/sdks/python/apache_beam/typehints/trivial_inference.py
+++ b/sdks/python/apache_beam/typehints/trivial_inference.py
@@ -49,7 +49,7 @@ def instance_to_type(o):
if o is None:
return type(None)
elif t == pvalue.Row:
- return row_type.RowTypeConstraint([
+ return row_type.RowTypeConstraint.from_fields([
(name, instance_to_type(value)) for name, value in o.as_dict().items()
])
elif t not in typehints.DISALLOWED_PRIMITIVE_TYPES:
@@ -438,8 +438,11 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
from apache_beam.pvalue import Row
if state.stack[-pop_count].value == Row:
fields = state.stack[-1].value
- return_type = row_type.RowTypeConstraint(
- zip(fields, Const.unwrap_all(state.stack[-pop_count + 1:-1])))
+ return_type = row_type.RowTypeConstraint.from_fields(
+ list(
+ zip(
+ fields,
+ Const.unwrap_all(state.stack[-pop_count + 1:-1]))))
else:
return_type = Any
else: