You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by GitBox <gi...@apache.org> on 2019/01/14 20:03:16 UTC
[beam] Diff for: [GitHub] aaltay merged pull request #7498: Made SampleFixedSizeGlobally and FixedSizePerKey PTransform classes
diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index e3df43632b21..dbb143e6e168 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -522,15 +522,35 @@ class Sample(object):
"""Combiners for sampling n elements without replacement."""
# pylint: disable=no-self-argument
- @staticmethod
- @ptransform.ptransform_fn
- def FixedSizeGlobally(pcoll, n):
- return pcoll | core.CombineGlobally(SampleCombineFn(n))
+ class FixedSizeGlobally(ptransform.PTransform):
+ """Sample n elements from the input PCollection without replacement."""
- @staticmethod
- @ptransform.ptransform_fn
- def FixedSizePerKey(pcoll, n):
- return pcoll | core.CombinePerKey(SampleCombineFn(n))
+ def __init__(self, n):
+ self._n = n
+
+ def expand(self, pcoll):
+ return pcoll | core.CombineGlobally(SampleCombineFn(self._n))
+
+ def display_data(self):
+ return {'n': self._n}
+
+ def default_label(self):
+ return 'FixedSizeGlobally(%d)' % self._n
+
+ class FixedSizePerKey(ptransform.PTransform):
+ """Sample n elements associated with each key without replacement."""
+
+ def __init__(self, n):
+ self._n = n
+
+ def expand(self, pcoll):
+ return pcoll | core.CombinePerKey(SampleCombineFn(self._n))
+
+ def display_data(self):
+ return {'n': self._n}
+
+ def default_label(self):
+ return 'FixedSizePerKey(%d)' % self._n
@with_input_types(T)
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index 637a41f3dcb5..3db019a03599 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -177,26 +177,16 @@ def individual_test_per_key_dd(combineFn):
individual_test_per_key_dd(combine.Largest(5))
def test_combine_sample_display_data(self):
- def individual_test_per_key_dd(sampleFn, args, kwargs):
- trs = [sampleFn(*args, **kwargs)]
+ def individual_test_per_key_dd(sampleFn, n):
+ trs = [sampleFn(n)]
for transform in trs:
dd = DisplayData.create_from(transform)
- expected_items = [
- DisplayDataItemMatcher('fn', transform._fn.__name__)]
- if args:
- expected_items.append(
- DisplayDataItemMatcher('args', str(args)))
- if kwargs:
- expected_items.append(
- DisplayDataItemMatcher('kwargs', str(kwargs)))
- hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
-
- individual_test_per_key_dd(combine.Sample.FixedSizePerKey,
- args=(5,),
- kwargs={})
- individual_test_per_key_dd(combine.Sample.FixedSizeGlobally,
- args=(8,),
- kwargs={'arg': 9})
+ hc.assert_that(
+ dd.items,
+ hc.contains_inanyorder(DisplayDataItemMatcher('n', transform._n)))
+
+ individual_test_per_key_dd(combine.Sample.FixedSizePerKey, 5)
+ individual_test_per_key_dd(combine.Sample.FixedSizeGlobally, 5)
def test_combine_globally_display_data(self):
transform = beam.CombineGlobally(combine.Smallest(5))
With regards,
Apache Git Services