You are viewing a plain-text version of this content. (The link to the canonical version referred to here was not preserved in this plain-text copy.)
Posted to commits@beam.apache.org by bh...@apache.org on 2022/07/23 00:48:06 UTC

[beam] branch master updated: Update element_type inference (default_type_hints) for batched DoFns with yields_batches/yields_elements (#22198)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new d050a088a3d Update element_type inference (default_type_hints) for batched DoFns with yields_batches/yields_elements   (#22198)
d050a088a3d is described below

commit d050a088a3de510d450b0ec6cfbd4e11b277b3ac
Author: Brian Hulette <bh...@google.com>
AuthorDate: Fri Jul 22 17:47:59 2022 -0700

    Update element_type inference (default_type_hints) for batched DoFns with yields_batches/yields_elements   (#22198)
    
    * Add tests checking element_type for batched DoFns
    
    * Fix element type inference for process method with yields_batches
    
    * fixup! Add tests checking element_type for batched DoFns
    
    * fixup! Fix element type inference for process method with yields_batches
    
    * fixup! Add tests checking element_type for batched DoFns
    
    * fixup! Add tests checking element_type for batched DoFns
    
    * fixup! Add tests checking element_type for batched DoFns
    
    * Improve clarity in default_type_hints
---
 .../apache_beam/transforms/batch_dofn_test.py      | 73 +++++++++++++++++++++-
 sdks/python/apache_beam/transforms/core.py         | 44 ++++++++++---
 sdks/python/apache_beam/typehints/decorators.py    | 10 +++
 .../apache_beam/typehints/decorators_test.py       | 40 ++++++++++++
 4 files changed, 159 insertions(+), 8 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/batch_dofn_test.py b/sdks/python/apache_beam/transforms/batch_dofn_test.py
index eb4e6ff0cab..b75f447d8b4 100644
--- a/sdks/python/apache_beam/transforms/batch_dofn_test.py
+++ b/sdks/python/apache_beam/transforms/batch_dofn_test.py
@@ -71,6 +71,9 @@ class ElementToBatchDoFn(beam.DoFn):
   def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
     yield [element] * element
 
+  def infer_output_type(self, input_element_type):
+    return input_element_type
+
 
 class BatchToElementDoFn(beam.DoFn):
   @beam.DoFn.yields_elements
@@ -170,6 +173,31 @@ class BatchDoFnNoInputAnnotation(beam.DoFn):
     yield [element * 2 for element in batch]
 
 
+class MismatchedBatchProducingDoFn(beam.DoFn):
+  """A DoFn that produces batches from both process and process_batch, with
+  mismatched return types (one yields floats, the other ints). Should raise
+  a construction-time TypeError when applied."""
+  @beam.DoFn.yields_batches
+  def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
+    yield [element]
+
+  def process_batch(self, batch: List[int], *args,
+                    **kwargs) -> Iterator[List[float]]:
+    yield [element / 2 for element in batch]
+
+
+class MismatchedElementProducingDoFn(beam.DoFn):
+  """A DoFn that produces elements from both process and process_batch, with
+  mismatched return types (one yields floats, the other ints). Should raise
+  a construction-time TypeError when applied."""
+  def process(self, element: int, *args, **kwargs) -> Iterator[float]:
+    yield element / 2
+
+  @beam.DoFn.yields_elements
+  def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[int]:
+    yield batch[0]
+
+
 class BatchDoFnTest(unittest.TestCase):
   def test_map_pardo(self):
     # verify batch dofn accessors work well with beam.Map generated DoFn
@@ -199,9 +227,52 @@ class BatchDoFnTest(unittest.TestCase):
     pc = p | beam.Create([1, 2, 3])
 
     with self.assertRaisesRegex(NotImplementedError,
-                                r'.*BatchDoFnBadParam.*KeyParam'):
+                                r'BatchDoFnBadParam.*KeyParam'):
       _ = pc | beam.ParDo(BatchDoFnBadParam())
 
+  def test_mismatched_batch_producer_raises(self):
+    p = beam.Pipeline()
+    pc = p | beam.Create([1, 2, 3])
+
+    # Note (?ms) turns on MULTILINE and DOTALL; DOTALL lets . match newlines.
+    # See (?aiLmsux) at
+    # https://docs.python.org/3/library/re.html#regular-expression-syntax
+    with self.assertRaisesRegex(
+        TypeError,
+        (r'(?ms)MismatchedBatchProducingDoFn.*'
+         r'process: List\[int\].*process_batch: List\[float\]')):
+      _ = pc | beam.ParDo(MismatchedBatchProducingDoFn())
+
+  def test_mismatched_element_producer_raises(self):
+    p = beam.Pipeline()
+    pc = p | beam.Create([1, 2, 3])
+
+    # Note (?ms) turns on MULTILINE and DOTALL; DOTALL lets . match newlines.
+    # See (?aiLmsux) at
+    # https://docs.python.org/3/library/re.html#regular-expression-syntax
+    with self.assertRaisesRegex(
+        TypeError,
+        r'(?ms)MismatchedElementProducingDoFn.*process:.*process_batch:'):
+      _ = pc | beam.ParDo(MismatchedElementProducingDoFn())
+
+  def test_element_to_batch_dofn_typehint(self):
+    # Verify that element to batch DoFn sets the correct typehint on the output
+    # PCollection.
+
+    p = beam.Pipeline()
+    pc = (p | beam.Create([1, 2, 3]) | beam.ParDo(ElementToBatchDoFn()))
+
+    self.assertEqual(pc.element_type, int)
+
+  def test_batch_to_element_dofn_typehint(self):
+    # Verify that batch to element DoFn sets the correct typehint on the output
+    # PCollection.
+
+    p = beam.Pipeline()
+    pc = (p | beam.Create([1, 2, 3]) | beam.ParDo(BatchToElementDoFn()))
+
+    self.assertEqual(pc.element_type, beam.typehints.Tuple[int, int])
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index dbf683a019b..51692896740 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -708,14 +708,44 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
     return get_function_arguments(self, func)
 
   def default_type_hints(self):
-    fn_type_hints = typehints.decorators.IOTypeHints.from_callable(self.process)
-    if fn_type_hints is not None:
-      try:
-        fn_type_hints = fn_type_hints.strip_iterable()
-      except ValueError as e:
-        raise ValueError('Return value not iterable: %s: %s' % (self, e))
+    process_type_hints = typehints.decorators.IOTypeHints.from_callable(
+        self.process) or typehints.decorators.IOTypeHints.empty()
+
+    if self._process_yields_batches:
+      # process() produces batches, so don't use its output typehint
+      process_type_hints = process_type_hints.with_output_types_from(
+          typehints.decorators.IOTypeHints.empty())
+
+    if self._process_batch_yields_elements:
+      # process_batch() produces elements, so *do* use its output typehint
+
+      # First look up process_batch's typehints
+      process_batch_type_hints = typehints.decorators.IOTypeHints.from_callable(
+          self.process_batch) or typehints.decorators.IOTypeHints.empty()
+
+      # Then deconflict with process's output typehint, if process has one
+      if (process_batch_type_hints.output_types !=
+          typehints.decorators.IOTypeHints.empty().output_types):
+        if (process_type_hints.output_types !=
+            typehints.decorators.IOTypeHints.empty().output_types and
+            process_batch_type_hints.output_types !=
+            process_type_hints.output_types):
+          raise TypeError(
+              f"DoFn {self!r} yields element from both process and "
+              "process_batch, but they have mismatched output typehints:\n"
+              f" process: {process_type_hints.output_types}\n"
+              f" process_batch: {process_batch_type_hints.output_types}")
+
+        process_type_hints = process_type_hints.with_output_types_from(
+            process_batch_type_hints)
+
+    try:
+      process_type_hints = process_type_hints.strip_iterable()
+    except ValueError as e:
+      raise ValueError('Return value not iterable: %s: %s' % (self, e))
+
     # Prefer class decorator type hints for backwards compatibility.
-    return get_type_hints(self.__class__).with_defaults(fn_type_hints)
+    return get_type_hints(self.__class__).with_defaults(process_type_hints)
 
   # TODO(sourabhbajaj): Do we want to remove the responsibility of these from
   # the DoFn or maybe the runner
diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py
index 1b6fe701416..c24f2ed8f43 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -302,6 +302,16 @@ class IOTypeHints(NamedTuple):
     return self._replace(
         output_types=(args, kwargs), origin=self._make_origin([self]))
 
+  def with_input_types_from(self, other):
+    # type: (IOTypeHints) -> IOTypeHints
+    return self._replace(  # only input_types comes from other
+        input_types=other.input_types, origin=self._make_origin([self]))
+
+  def with_output_types_from(self, other):
+    # type: (IOTypeHints) -> IOTypeHints
+    return self._replace(  # only output_types comes from other
+        output_types=other.output_types, origin=self._make_origin([self]))
+
   def simple_output_type(self, context):
     if self._has_output_types():
       args, kwargs = self.output_types
diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py
index 9b109969bf3..ba46038e472 100644
--- a/sdks/python/apache_beam/typehints/decorators_test.py
+++ b/sdks/python/apache_beam/typehints/decorators_test.py
@@ -97,6 +97,46 @@ class IOTypeHintsTest(unittest.TestCase):
     after = th.strip_iterable()
     self.assertEqual(((expected_after, ), {}), after.output_types)
 
+  def test_with_output_types_from(self):
+    th = decorators.IOTypeHints(
+        input_types=((int, ), {
+            'foo': str
+        }),
+        output_types=((int, str), {}),
+        origin=[])
+
+    self.assertEqual(
+        th.with_output_types_from(decorators.IOTypeHints.empty()),
+        decorators.IOTypeHints(
+            input_types=((int, ), {
+                'foo': str
+            }), output_types=None, origin=[]))
+
+    self.assertEqual(
+        decorators.IOTypeHints.empty().with_output_types_from(th),
+        decorators.IOTypeHints(
+            input_types=None, output_types=((int, str), {}), origin=[]))
+
+  def test_with_input_types_from(self):
+    th = decorators.IOTypeHints(
+        input_types=((int, ), {
+            'foo': str
+        }),
+        output_types=((int, str), {}),
+        origin=[])
+
+    self.assertEqual(
+        th.with_input_types_from(decorators.IOTypeHints.empty()),
+        decorators.IOTypeHints(
+            input_types=None, output_types=((int, str), {}), origin=[]))
+
+    self.assertEqual(
+        decorators.IOTypeHints.empty().with_input_types_from(th),
+        decorators.IOTypeHints(
+            input_types=((int, ), {
+                'foo': str
+            }), output_types=None, origin=[]))
+
   def _test_strip_iterable_fail(self, before):
     with self.assertRaisesRegex(ValueError, r'not iterable'):
       self._test_strip_iterable(before, None)