You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/06/29 21:18:27 UTC

[beam] branch master updated: Test and fix FlatMap(<builtin>) issue (#22104)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ca112ad949f Test and fix FlatMap(<builtin>) issue (#22104)
ca112ad949f is described below

commit ca112ad949f52cde0c8c0fdc7074bf4a7e4c72f1
Author: Brian Hulette <bh...@google.com>
AuthorDate: Wed Jun 29 14:18:21 2022 -0700

    Test and fix FlatMap(<builtin>) issue (#22104)
    
    * 22091: Add tests of Map, FlatMap, and Filter with builtins
    
    * 22091: Modify _process_batch_defined and _process_defined to handle builtins
---
 sdks/python/apache_beam/transforms/core.py          | 21 ++++++++++++++-------
 .../apache_beam/transforms/ptransform_test.py       | 21 +++++++++++++++++++++
 2 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 5fd561ee780..0a85135204e 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -726,16 +726,23 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
 
   @property
   def _process_defined(self) -> bool:
-    return (
-        self.process.__func__  # type: ignore
-        if hasattr(self.process, '__self__') else self.process) != DoFn.process
+    # Check if this DoFn's process method has been overridden
+    # Note that we retrieve the __func__ attribute, if it exists, to get the
+    # underlying function from the bound method.
+    # If __func__ doesn't exist, self.process was likely overridden with a free
+    # function, as in CallableWrapperDoFn.
+    return getattr(self.process, '__func__', self.process) != DoFn.process
 
   @property
   def _process_batch_defined(self) -> bool:
-    return (
-        self.process_batch.__func__  # type: ignore
-        if hasattr(self.process_batch, '__self__')
-        else self.process_batch) != DoFn.process_batch
+    # Check if this DoFn's process_batch method has been overridden
+    # Note that we retrieve the __func__ attribute, if it exists, to get the
+    # underlying function from the bound method.
+    # If __func__ doesn't exist, self.process_batch was likely overridden with
+    # a free function.
+    return getattr(
+        self.process_batch, '__func__',
+        self.process_batch) != DoFn.process_batch
 
   @property
   def _can_yield_batches(self) -> bool:
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index 9870ff7ed46..2dbe55ceb65 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -32,6 +32,7 @@ from typing import Optional
 from unittest.mock import patch
 
 import hamcrest as hc
+import numpy as np
 import pytest
 from parameterized import parameterized_class
 
@@ -380,6 +381,26 @@ class PTransformTest(unittest.TestCase):
       with TestPipeline() as p:
         p | 'Start' >> beam.Create([1, 2, 3]) | 'Do' >> beam.ParDo(MyDoFn())
 
+  def test_map_builtin(self):
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'Start' >> beam.Create([[1, 2], [1], [1, 2, 3]])
+      result = pcoll | beam.Map(len)
+      assert_that(result, equal_to([1, 2, 3]))
+
+  def test_flatmap_builtin(self):
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'Start' >> beam.Create([
+          [np.array([1, 2, 3])] * 3, [np.array([5, 4, 3]), np.array([5, 6, 7])]
+      ])
+      result = pcoll | beam.FlatMap(sum)
+      assert_that(result, equal_to([3, 6, 9, 10, 10, 10]))
+
+  def test_filter_builtin(self):
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'Start' >> beam.Create([[], [2], [], [4]])
+      result = pcoll | 'Filter' >> beam.Filter(len)
+      assert_that(result, equal_to([[2], [4]]))
+
   def test_filter(self):
     with TestPipeline() as pipeline:
       pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3, 4])