You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/07/23 00:48:06 UTC
[beam] branch master updated: Update element_type inference (default_type_hints) for batched DoFns with yields_batches/yields_elements (#22198)
This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new d050a088a3d Update element_type inference (default_type_hints) for batched DoFns with yields_batches/yields_elements (#22198)
d050a088a3d is described below
commit d050a088a3de510d450b0ec6cfbd4e11b277b3ac
Author: Brian Hulette <bh...@google.com>
AuthorDate: Fri Jul 22 17:47:59 2022 -0700
Update element_type inference (default_type_hints) for batched DoFns with yields_batches/yields_elements (#22198)
* Add tests checking element_type for batched DoFns
* Fix element type inference for process method with yields_batches
* fixup! Add tests checking element_type for batched DoFns
* fixup! Fix element type inference for process method with yields_batches
* fixup! Add tests checking element_type for batched DoFns
* fixup! Add tests checking element_type for batched DoFns
* fixup! Add tests checking element_type for batched DoFns
* Improve clarity in default_type_hints
---
.../apache_beam/transforms/batch_dofn_test.py | 73 +++++++++++++++++++++-
sdks/python/apache_beam/transforms/core.py | 44 ++++++++++---
sdks/python/apache_beam/typehints/decorators.py | 10 +++
.../apache_beam/typehints/decorators_test.py | 40 ++++++++++++
4 files changed, 159 insertions(+), 8 deletions(-)
diff --git a/sdks/python/apache_beam/transforms/batch_dofn_test.py b/sdks/python/apache_beam/transforms/batch_dofn_test.py
index eb4e6ff0cab..b75f447d8b4 100644
--- a/sdks/python/apache_beam/transforms/batch_dofn_test.py
+++ b/sdks/python/apache_beam/transforms/batch_dofn_test.py
@@ -71,6 +71,9 @@ class ElementToBatchDoFn(beam.DoFn):
def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
yield [element] * element
+ def infer_output_type(self, input_element_type):
+ return input_element_type
+
class BatchToElementDoFn(beam.DoFn):
@beam.DoFn.yields_elements
@@ -170,6 +173,31 @@ class BatchDoFnNoInputAnnotation(beam.DoFn):
yield [element * 2 for element in batch]
+class MismatchedBatchProducingDoFn(beam.DoFn):
+ """A DoFn that produces batches from both process and process_batch, with
+ mismatched return types (one yields floats, the other ints). Should yield
+ a construction time error when applied."""
+ @beam.DoFn.yields_batches
+ def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
+ yield [element]
+
+ def process_batch(self, batch: List[int], *args,
+ **kwargs) -> Iterator[List[float]]:
+ yield [element / 2 for element in batch]
+
+
+class MismatchedElementProducingDoFn(beam.DoFn):
+ """A DoFn that produces elements from both process and process_batch, with
+ mismatched return types (one yields floats, the other ints). Should yield
+ a construction time error when applied."""
+ def process(self, element: int, *args, **kwargs) -> Iterator[float]:
+ yield element / 2
+
+ @beam.DoFn.yields_elements
+ def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[int]:
+ yield batch[0]
+
+
class BatchDoFnTest(unittest.TestCase):
def test_map_pardo(self):
# verify batch dofn accessors work well with beam.Map generated DoFn
@@ -199,9 +227,52 @@ class BatchDoFnTest(unittest.TestCase):
pc = p | beam.Create([1, 2, 3])
with self.assertRaisesRegex(NotImplementedError,
- r'.*BatchDoFnBadParam.*KeyParam'):
+ r'BatchDoFnBadParam.*KeyParam'):
_ = pc | beam.ParDo(BatchDoFnBadParam())
+ def test_mismatched_batch_producer_raises(self):
+ p = beam.Pipeline()
+ pc = p | beam.Create([1, 2, 3])
+
+ # Note (?ms) makes this a multiline regex, where . matches newlines.
+ # See (?aiLmsux) at
+ # https://docs.python.org/3.4/library/re.html#regular-expression-syntax
+ with self.assertRaisesRegex(
+ TypeError,
+ (r'(?ms)MismatchedBatchProducingDoFn.*'
+ r'process: List\[int\].*process_batch: List\[float\]')):
+ _ = pc | beam.ParDo(MismatchedBatchProducingDoFn())
+
+ def test_mismatched_element_producer_raises(self):
+ p = beam.Pipeline()
+ pc = p | beam.Create([1, 2, 3])
+
+ # Note (?ms) makes this a multiline regex, where . matches newlines.
+ # See (?aiLmsux) at
+ # https://docs.python.org/3.4/library/re.html#regular-expression-syntax
+ with self.assertRaisesRegex(
+ TypeError,
+ r'(?ms)MismatchedElementProducingDoFn.*process:.*process_batch:'):
+ _ = pc | beam.ParDo(MismatchedElementProducingDoFn())
+
+ def test_element_to_batch_dofn_typehint(self):
+ # Verify that element to batch DoFn sets the correct typehint on the output
+ # PCollection.
+
+ p = beam.Pipeline()
+ pc = (p | beam.Create([1, 2, 3]) | beam.ParDo(ElementToBatchDoFn()))
+
+ self.assertEqual(pc.element_type, int)
+
+ def test_batch_to_element_dofn_typehint(self):
+ # Verify that batch to element DoFn sets the correct typehint on the output
+ # PCollection.
+
+ p = beam.Pipeline()
+ pc = (p | beam.Create([1, 2, 3]) | beam.ParDo(BatchToElementDoFn()))
+
+ self.assertEqual(pc.element_type, beam.typehints.Tuple[int, int])
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index dbf683a019b..51692896740 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -708,14 +708,44 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
return get_function_arguments(self, func)
def default_type_hints(self):
- fn_type_hints = typehints.decorators.IOTypeHints.from_callable(self.process)
- if fn_type_hints is not None:
- try:
- fn_type_hints = fn_type_hints.strip_iterable()
- except ValueError as e:
- raise ValueError('Return value not iterable: %s: %s' % (self, e))
+ process_type_hints = typehints.decorators.IOTypeHints.from_callable(
+ self.process) or typehints.decorators.IOTypeHints.empty()
+
+ if self._process_yields_batches:
+ # process() produces batches, don't use it's output typehint
+ process_type_hints = process_type_hints.with_output_types_from(
+ typehints.decorators.IOTypeHints.empty())
+
+ if self._process_batch_yields_elements:
+ # process_batch() produces elements, *do* use it's output typehint
+
+ # First access the typehint
+ process_batch_type_hints = typehints.decorators.IOTypeHints.from_callable(
+ self.process_batch) or typehints.decorators.IOTypeHints.empty()
+
+ # Then we deconflict with the typehint from process, if it exists
+ if (process_batch_type_hints.output_types !=
+ typehints.decorators.IOTypeHints.empty().output_types):
+ if (process_type_hints.output_types !=
+ typehints.decorators.IOTypeHints.empty().output_types and
+ process_batch_type_hints.output_types !=
+ process_type_hints.output_types):
+ raise TypeError(
+ f"DoFn {self!r} yields element from both process and "
+ "process_batch, but they have mismatched output typehints:\n"
+ f" process: {process_type_hints.output_types}\n"
+ f" process_batch: {process_batch_type_hints.output_types}")
+
+ process_type_hints = process_type_hints.with_output_types_from(
+ process_batch_type_hints)
+
+ try:
+ process_type_hints = process_type_hints.strip_iterable()
+ except ValueError as e:
+ raise ValueError('Return value not iterable: %s: %s' % (self, e))
+
# Prefer class decorator type hints for backwards compatibility.
- return get_type_hints(self.__class__).with_defaults(fn_type_hints)
+ return get_type_hints(self.__class__).with_defaults(process_type_hints)
# TODO(sourabhbajaj): Do we want to remove the responsibility of these from
# the DoFn or maybe the runner
diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py
index 1b6fe701416..c24f2ed8f43 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -302,6 +302,16 @@ class IOTypeHints(NamedTuple):
return self._replace(
output_types=(args, kwargs), origin=self._make_origin([self]))
+ def with_input_types_from(self, other):
+ # type: (IOTypeHints) -> IOTypeHints
+ return self._replace(
+ input_types=other.input_types, origin=self._make_origin([self]))
+
+ def with_output_types_from(self, other):
+ # type: (IOTypeHints) -> IOTypeHints
+ return self._replace(
+ output_types=other.output_types, origin=self._make_origin([self]))
+
def simple_output_type(self, context):
if self._has_output_types():
args, kwargs = self.output_types
diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py
index 9b109969bf3..ba46038e472 100644
--- a/sdks/python/apache_beam/typehints/decorators_test.py
+++ b/sdks/python/apache_beam/typehints/decorators_test.py
@@ -97,6 +97,46 @@ class IOTypeHintsTest(unittest.TestCase):
after = th.strip_iterable()
self.assertEqual(((expected_after, ), {}), after.output_types)
+ def test_with_output_types_from(self):
+ th = decorators.IOTypeHints(
+ input_types=((int), {
+ 'foo': str
+ }),
+ output_types=((int, str), {}),
+ origin=[])
+
+ self.assertEqual(
+ th.with_output_types_from(decorators.IOTypeHints.empty()),
+ decorators.IOTypeHints(
+ input_types=((int), {
+ 'foo': str
+ }), output_types=None, origin=[]))
+
+ self.assertEqual(
+ decorators.IOTypeHints.empty().with_output_types_from(th),
+ decorators.IOTypeHints(
+ input_types=None, output_types=((int, str), {}), origin=[]))
+
+ def test_with_input_types_from(self):
+ th = decorators.IOTypeHints(
+ input_types=((int), {
+ 'foo': str
+ }),
+ output_types=((int, str), {}),
+ origin=[])
+
+ self.assertEqual(
+ th.with_input_types_from(decorators.IOTypeHints.empty()),
+ decorators.IOTypeHints(
+ input_types=None, output_types=((int, str), {}), origin=[]))
+
+ self.assertEqual(
+ decorators.IOTypeHints.empty().with_input_types_from(th),
+ decorators.IOTypeHints(
+ input_types=((int), {
+ 'foo': str
+ }), output_types=None, origin=[]))
+
def _test_strip_iterable_fail(self, before):
with self.assertRaisesRegex(ValueError, r'not iterable'):
self._test_strip_iterable(before, None)