You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@beam.apache.org by al...@apache.org on 2019/05/22 01:01:13 UTC

[beam] branch master updated: [BEAM-6695] Latest PTransform for Python SDK (#8206)

This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 766a765  [BEAM-6695] Latest PTransform for Python SDK (#8206)
766a765 is described below

commit 766a76538d806c225574e8eabd7b25a3f4cf3e08
Author: Tanay Tummalapalli <tt...@gmail.com>
AuthorDate: Wed May 22 06:30:53 2019 +0530

    [BEAM-6695] Latest PTransform for Python SDK (#8206)
    
    * [BEAM-6695] Latest PTransform for Python SDK
    
    Added Latest PTransform and Combine Fns for the Python SDK.
    Latest PTransform is used to compute the element(s) with the
    latest timestamp from a PCollection.
---
 sdks/python/apache_beam/transforms/combiners.py    | 66 ++++++++++++++++
 .../apache_beam/transforms/combiners_test.py       | 89 ++++++++++++++++++++++
 2 files changed, 155 insertions(+)

diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index e8345a1..94a67fa 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -45,6 +45,8 @@ from apache_beam.typehints import TypeVariable
 from apache_beam.typehints import Union
 from apache_beam.typehints import with_input_types
 from apache_beam.typehints import with_output_types
+from apache_beam.utils.timestamp import Duration
+from apache_beam.utils.timestamp import Timestamp
 
 __all__ = [
     'Count',
@@ -53,12 +55,14 @@ __all__ = [
     'Top',
     'ToDict',
     'ToList',
+    'Latest',
     ]
 
 # Type variables
 T = TypeVariable('T')
 K = TypeVariable('K')
 V = TypeVariable('V')
+TimestampType = Union[int, long, float, Timestamp, Duration]
 
 
 class Mean(object):
@@ -858,3 +862,65 @@ class PhasedCombineFnExecutor(object):
 
   def extract_only(self, accumulator):
     return self.combine_fn.extract_output(accumulator)
+
+
+class Latest(object):
+  """Transforms that select the element(s) with the latest timestamp."""
+
+  @with_input_types(T)
+  @with_output_types(T)
+  class Globally(ptransform.PTransform):
+    """Compute the element with the latest timestamp from a PCollection.
+    For an empty input the result is None."""
+
+    @staticmethod
+    def add_timestamp(element, timestamp=core.DoFn.TimestampParam):
+      return [(element, timestamp)]
+
+    def expand(self, pcoll):
+      return (pcoll
+              | core.ParDo(self.add_timestamp)
+              .with_output_types(Tuple[T, TimestampType])
+              | core.CombineGlobally(LatestCombineFn()))
+
+  @with_input_types(KV[K, V])
+  @with_output_types(KV[K, V])
+  class PerKey(ptransform.PTransform):
+    """Compute the value with the latest timestamp for each key of a
+    keyed PCollection, emitting (key, value) pairs."""
+
+    @staticmethod
+    def add_timestamp(element, timestamp=core.DoFn.TimestampParam):
+      key, value = element
+      return [(key, (value, timestamp))]
+
+    def expand(self, pcoll):
+      return (pcoll
+              | core.ParDo(self.add_timestamp)
+              .with_output_types(KV[K, Tuple[V, TimestampType]])
+              | core.CombinePerKey(LatestCombineFn()))
+
+
+@with_input_types(Tuple[T, TimestampType])
+@with_output_types(T)
+class LatestCombineFn(core.CombineFn):
+  """CombineFn keeping the (element, timestamp) pair with the greatest
+  timestamp; extract_output returns that pair's element."""
+
+  def create_accumulator(self):
+    return None, window.MIN_TIMESTAMP
+
+  def add_input(self, accumulator, element):
+    # Strict '>': a pair whose timestamp ties the current one replaces it.
+    current_ts, candidate_ts = accumulator[1], element[1]
+    keep_current = current_ts > candidate_ts
+    return accumulator if keep_current else element
+
+  def merge_accumulators(self, accumulators):
+    merged = self.create_accumulator()
+    for partial in accumulators:
+      merged = self.add_input(merged, partial)
+    return merged
+
+  def extract_output(self, accumulator):
+    return accumulator[0]
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index a73fbac..8526825 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -32,12 +32,14 @@ import apache_beam.transforms.combiners as combine
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms import window
 from apache_beam.transforms.core import CombineGlobally
 from apache_beam.transforms.core import Create
 from apache_beam.transforms.core import Map
 from apache_beam.transforms.display import DisplayData
 from apache_beam.transforms.display_test import DisplayDataItemMatcher
 from apache_beam.transforms.ptransform import PTransform
+from apache_beam.typehints import TypeCheckError
 
 
 class CombineTest(unittest.TestCase):
@@ -392,5 +394,92 @@ class CombineTest(unittest.TestCase):
       assert_that(result, equal_to([49.5]))
 
 
+class LatestTest(unittest.TestCase):
+
+  def test_globally(self):
+    elements = [window.TimestampedValue(3, 100),
+                window.TimestampedValue(1, 200),
+                window.TimestampedValue(2, 300)]
+    with TestPipeline() as p:
+      # Map(lambda x: x) PTransform is added after Create here, because when
+      # a PCollection of TimestampedValues is created with Create PTransform,
+      # the timestamps are not assigned to it. Adding a Map forces the
+      # PCollection to go through a DoFn so that the PCollection consists of
+      # the elements with timestamps assigned to them instead of a PCollection
+      # of TimestampedValue(element, timestamp).
+      pc = p | Create(elements) | Map(lambda x: x)
+      latest = pc | combine.Latest.Globally()
+      assert_that(latest, equal_to([2]))
+
+  def test_globally_empty(self):
+    elements = []
+    with TestPipeline() as p:
+      pc = p | Create(elements) | Map(lambda x: x)
+      latest = pc | combine.Latest.Globally()
+      assert_that(latest, equal_to([None]))
+
+  def test_per_key(self):
+    elements = [window.TimestampedValue(('a', 1), 300),
+                window.TimestampedValue(('b', 3), 100),
+                window.TimestampedValue(('a', 2), 200)]
+    with TestPipeline() as p:
+      pc = p | Create(elements) | Map(lambda x: x)
+      latest = pc | combine.Latest.PerKey()
+      assert_that(latest, equal_to([('a', 1), ('b', 3)]))
+
+  def test_per_key_empty(self):
+    elements = []
+    with TestPipeline() as p:
+      pc = p | Create(elements) | Map(lambda x: x)
+      latest = pc | combine.Latest.PerKey()
+      assert_that(latest, equal_to([]))
+
+
+class LatestCombineFnTest(unittest.TestCase):
+
+  def setUp(self):
+    self.fn = combine.LatestCombineFn()
+
+  def test_create_accumulator(self):
+    accumulator = self.fn.create_accumulator()
+    self.assertEqual(accumulator, (None, window.MIN_TIMESTAMP))
+
+  def test_add_input(self):
+    accumulator = self.fn.create_accumulator()
+    element = (1, 100)
+    new_accumulator = self.fn.add_input(accumulator, element)
+    self.assertEqual(new_accumulator, (1, 100))
+
+  def test_merge_accumulators(self):
+    accumulators = [(2, 400), (5, 100), (9, 200)]
+    merged_accumulator = self.fn.merge_accumulators(accumulators)
+    self.assertEqual(merged_accumulator, (2, 400))
+
+  def test_extract_output(self):
+    accumulator = (1, 100)
+    output = self.fn.extract_output(accumulator)
+    self.assertEqual(output, 1)
+
+  def test_with_input_types_decorator_violation(self):
+    l_int = [1, 2, 3]
+    l_dict = [{'a': 3}, {'g': 5}, {'r': 8}]
+    l_3_tuple = [(12, 31, 41), (12, 34, 34), (84, 92, 74)]
+
+    with self.assertRaises(TypeCheckError):
+      with TestPipeline() as p:
+        pc = p | Create(l_int)
+        _ = pc | beam.CombineGlobally(self.fn)
+
+    with self.assertRaises(TypeCheckError):
+      with TestPipeline() as p:
+        pc = p | Create(l_dict)
+        _ = pc | beam.CombineGlobally(self.fn)
+
+    with self.assertRaises(TypeCheckError):
+      with TestPipeline() as p:
+        pc = p | Create(l_3_tuple)
+        _ = pc | beam.CombineGlobally(self.fn)
+
+
 if __name__ == '__main__':
   unittest.main()