Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/06/14 12:15:55 UTC

[GitHub] [beam] yeandy commented on a diff in pull request #21844: Document and test overriding batch type inference

yeandy commented on code in PR #21844:
URL: https://github.com/apache/beam/pull/21844#discussion_r896688952


##########
sdks/python/apache_beam/transforms/core.py:
##########
@@ -770,16 +794,32 @@ def _get_element_type_from_return_annotation(method, input_type):
           f"{method!r}, did you mean Iterator[{return_type}]?")
 
   def get_output_batch_type(
-      self, input_element_type) -> typing.Optional[TypeConstraint]:
+      self, input_element_type
+  ) -> typing.Optional[typing.Union[TypeConstraint, type]]:
+    """Determine the batch type produced by this DoFn's ``process_batch``
+    implementation and/or it's ``process`` implementation with

Review Comment:
   ```suggestion
       implementation and/or its ``process`` implementation with
   ```



##########
sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py:
##########
@@ -136,6 +136,37 @@ def test_batch_pardo(self):
 
       assert_that(res, equal_to([6, 12, 18]))
 
+  def test_batch_pardo_override_type_inference(self):
+

Review Comment:
   Remove extra line
   ```suggestion
   ```



##########
sdks/python/apache_beam/transforms/core.py:
##########
@@ -737,7 +737,23 @@ def process_yields_batches(self) -> bool:
   def process_batch_yields_elements(self) -> bool:
     return getattr(self.process_batch, '_beam_yields_elements', False)
 
-  def get_input_batch_type(self) -> typing.Optional[TypeConstraint]:
+  def get_input_batch_type(
+      self, input_element_type
+  ) -> typing.Optional[typing.Union[TypeConstraint, type]]:
+    """Determine the batch type expected as input to process_batch.
+
+    The default implementation of ``get_input_batch_type`` simply observes the
+    input typehint for the first parameter of ``process_batch``. A Batched DoFn
+    may override this method if a dynamic approach is required.
+
+    Args:
+      input_element_type: The **element type** of the input PCollection this
+        DoFn is being applied to.
+
+    Returns:
+      ``None`` if this DoFn cannot accept batches, a Beam typehint or a native

Review Comment:
   ```suggestion
         ``None`` if this DoFn cannot accept batches, a Beam typehint, or a native
   ```
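The docstring in this hunk says the default ``get_input_batch_type`` reads the type hint on the first parameter of ``process_batch``. The new error message in core.py asks users to either annotate that parameter or override the method. Below is a minimal standard-library sketch of that lookup. The helper ``default_input_batch_type`` and the two stand-in classes are hypothetical, and this is not Beam's implementation. In particular, returning ``None`` when ``process_batch`` is absent is a simplification for plain classes.

```python
import inspect
import typing
from typing import List, Optional


def default_input_batch_type(dofn) -> Optional[type]:
  """Return the annotation on the first parameter of dofn.process_batch.

  Hypothetical helper illustrating the documented default behaviour;
  it is not Beam's implementation.
  """
  process_batch = getattr(dofn, 'process_batch', None)
  if process_batch is None:
    # Simplification: treat a missing process_batch as "cannot accept batches".
    return None
  # get_type_hints resolves annotations; a bound method's signature omits self.
  hints = typing.get_type_hints(process_batch)
  params = list(inspect.signature(process_batch).parameters)
  if not params or params[0] not in hints:
    name = type(dofn).__name__
    # Same wording as the new TypeError in the diff.
    raise TypeError(
        f"Either {name}.process_batch() must have a type annotation on its "
        f"first parameter, or {name} must override get_input_batch_type.")
  return hints[params[0]]


class AnnotatedDoFn:
  # Stand-in for a beam.DoFn subclass whose process_batch is annotated.
  def process_batch(self, batch: List[int], *args, **kwargs):
    yield [element * 2 for element in batch]


class UnannotatedDoFn:
  # Stand-in for a DoFn that neither annotates nor overrides.
  def process_batch(self, batch, *args, **kwargs):
    yield [element * 2 for element in batch]


print(default_input_batch_type(AnnotatedDoFn()))  # typing.List[int]
```

Calling ``default_input_batch_type(UnannotatedDoFn())`` raises the ``TypeError``. That is the case ``BatchDoFnOverrideTypeInference`` avoids by overriding ``get_input_batch_type``.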



##########
sdks/python/apache_beam/transforms/batch_dofn_test.py:
##########
@@ -46,6 +46,17 @@ def process_batch(self, batch: List[int], *args, **kwargs):
     yield [element * 2 for element in batch]
 
 
+class BatchDoFnOverrideTypeInference(beam.DoFn):
+  def process_batch(self, batch, *args, **kwargs):
+    yield [element * 2 for element in batch]
+
+  def get_input_batch_type(self, input_element_type):
+    return List[input_element_type]
+
+  def get_output_batch_type(self, input_element_type):
+    return List[input_element_type]

Review Comment:
   Would it make more sense for users to call `self.get_input_batch_type(input_element_type)` instead of repeating `List[input_element_type]`?
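The reviewer's suggestion can be sketched as follows. ``get_output_batch_type`` delegates to ``get_input_batch_type``, so the ``List[input_element_type]`` expression lives in one place. A plain class stands in for ``beam.DoFn`` so the snippet runs without Beam installed. In the test file, the class would subclass ``beam.DoFn`` as ``BatchDoFnOverrideTypeInference`` does. The delegation is only appropriate while the output batch type equals the input batch type, as the original test assumes by returning ``List[input_element_type]`` from both hooks.

```python
from typing import List


class BatchDoFnOverrideTypeInferenceSketch:
  # Hypothetical stand-in; the real test class subclasses beam.DoFn.
  def process_batch(self, batch, *args, **kwargs):
    yield [element * 2 for element in batch]

  def get_input_batch_type(self, input_element_type):
    # Batches are lists of the input PCollection's element type.
    return List[input_element_type]

  def get_output_batch_type(self, input_element_type):
    # Output batches have the same type as input batches, so reuse the
    # input hook instead of repeating List[input_element_type].
    return self.get_input_batch_type(input_element_type)


dofn = BatchDoFnOverrideTypeInferenceSketch()
print(dofn.get_output_batch_type(int))  # typing.List[int]
print(list(dofn.process_batch([1, 2, 3])))  # [[2, 4, 6]]
```

There is one trade-off. If a subclass later overrides only ``get_input_batch_type``, the output type follows it automatically, and that may not always be what the subclass intends.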



##########
sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py:
##########
@@ -136,6 +136,37 @@ def test_batch_pardo(self):
 
       assert_that(res, equal_to([6, 12, 18]))
 
+  def test_batch_pardo_override_type_inference(self):
+
+    class ArrayMultiplyTransposedDoFn(beam.DoFn):

Review Comment:
   I might be misunderstanding something, but why is "Transposed" in the name?



##########
sdks/python/apache_beam/transforms/core.py:
##########
@@ -746,10 +762,18 @@ def get_input_batch_type(self) -> typing.Optional[TypeConstraint]:
       # TODO(BEAM-14340): Consider supporting an alternative (dynamic?) approach
       # for declaring input type
       raise TypeError(
-          f"{self.__class__.__name__}.process_batch() does not have a type "
-          "annotation on its first parameter. This is required for "
-          "process_batch implementations.")
-    return typehints.native_type_compatibility.convert_to_beam_type(input_type)
+          f"Either {self.__class__.__name__}.process_batch() must have a type "
+          f"annotation on its first parameter, or {self.__class__.__name__} "
+          "must override get_input_batch_type.")
+    return input_type
+
+  def _get_input_batch_type_normalized(self, input_element_type):
+    return typehints.native_type_compatibility.convert_to_beam_type(
+        self.get_input_batch_type(input_element_type))
+
+  def _get_output_batch_type_normalized(self, input_element_type):
+    return typehints.native_type_compatibility.convert_to_beam_type(
+        self.get_output_batch_type(input_element_type))

Review Comment:
   Why are these private functions? Is it because normalizing to Beam types isn't going to be a common op?
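The hunk splits each hook into two methods. One is a public, overridable method that may return either a Beam typehint or a native type. The other is a private ``_get_*_batch_type_normalized`` wrapper that passes the result through ``convert_to_beam_type``. One reading of this design is that users implement only the public hook, while framework code calls the normalized wrapper. Under that reading, overrides never have to normalize their own return values. Below is a minimal sketch of the split. The function ``to_canonical`` is a hypothetical stand-in for ``convert_to_beam_type``: it only splits a native generic into its origin and arguments with ``typing.get_origin``/``typing.get_args``, which is far simpler than what Beam does.

```python
import typing
from typing import List


def to_canonical(hint):
  """Hypothetical stand-in for convert_to_beam_type.

  Maps a native generic such as List[int] to an (origin, args) tuple and
  returns anything else (None, plain classes) unchanged.
  """
  origin = typing.get_origin(hint)
  if origin is None:
    return hint
  return (origin, typing.get_args(hint))


class BatchTypeHooks:
  # Public hook: users override this and may return a native typing hint.
  def get_input_batch_type(self, input_element_type):
    return None

  # Private wrapper: framework code calls this to get one canonical form.
  def _get_input_batch_type_normalized(self, input_element_type):
    return to_canonical(self.get_input_batch_type(input_element_type))


class ListBatches(BatchTypeHooks):
  def get_input_batch_type(self, input_element_type):
    return List[input_element_type]


print(ListBatches()._get_input_batch_type_normalized(int))
# (<class 'list'>, (<class 'int'>,))
print(BatchTypeHooks()._get_input_batch_type_normalized(int))  # None
```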



##########
sdks/python/apache_beam/transforms/core.py:
##########
@@ -770,16 +794,32 @@ def _get_element_type_from_return_annotation(method, input_type):
           f"{method!r}, did you mean Iterator[{return_type}]?")
 
   def get_output_batch_type(
-      self, input_element_type) -> typing.Optional[TypeConstraint]:
+      self, input_element_type
+  ) -> typing.Optional[typing.Union[TypeConstraint, type]]:
+    """Determine the batch type produced by this DoFn's ``process_batch``
+    implementation and/or it's ``process`` implementation with
+    ``@yields_batch``.
+
+    The default implementation of this method observes the return type
+    annotations on ``process_batch`` and/or ``process``.  A Batched DoFn may
+    override this method if a dynamic approach is required.
+
+    Args:
+      input_element_type: The **element type** of the input PCollection this
+        DoFn is being applied to.
+
+    Returns:
+      ``None`` if this DoFn will never yield batches, a Beam typehint or

Review Comment:
   ```suggestion
         ``None`` if this DoFn will never yield batches, a Beam typehint, or
   ```
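Per the docstring, the default ``get_output_batch_type`` reads the return annotation of ``process_batch``, or of a ``process`` decorated with ``@yields_batch``. The context line from ``_get_element_type_from_return_annotation`` ("did you mean Iterator[...]?") suggests the expected form is ``Iterator[BatchType]``. Below is a standard-library sketch that extracts the batch type from such an annotation. The helper ``batch_type_from_return_annotation`` and the ``Doubler`` class are hypothetical, and this is not Beam's implementation.

```python
import collections.abc
import typing
from typing import Iterator, List, Optional


def batch_type_from_return_annotation(method) -> Optional[type]:
  """Return X from a `-> Iterator[X]` annotation, or None if unannotated.

  Hypothetical helper; not Beam's implementation.
  """
  return_type = typing.get_type_hints(method).get('return')
  if return_type is None:
    # No return annotation: nothing to infer.
    return None
  args = typing.get_args(return_type)
  if (typing.get_origin(return_type) is not collections.abc.Iterator
      or len(args) != 1):
    raise TypeError(
        f"{method!r} must be annotated as returning Iterator[BatchType], "
        f"got {return_type!r}")
  return args[0]


class Doubler:
  # Stand-in for a Batched DoFn with an annotated process_batch.
  def process_batch(self, batch: List[int]) -> Iterator[List[int]]:
    yield [element * 2 for element in batch]


print(batch_type_from_return_annotation(Doubler().process_batch))
# typing.List[int]
```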



##########
sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py:
##########
@@ -136,6 +136,37 @@ def test_batch_pardo(self):
 
       assert_that(res, equal_to([6, 12, 18]))
 
+  def test_batch_pardo_override_type_inference(self):
+
+    class ArrayMultiplyTransposedDoFn(beam.DoFn):
+

Review Comment:
   Remove extra line
   ```suggestion
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org