You are viewing a plain-text version of this content. The canonical link for it was a hyperlink on the original archive page and is not preserved in this text rendering.
Posted to commits@beam.apache.org by GitBox <gi...@apache.org> on 2019/01/14 20:03:16 UTC

[beam] Diff for: [GitHub] aaltay merged pull request #7498: Made Sample.FixedSizeGlobally and Sample.FixedSizePerKey PTransform classes

diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index e3df43632b21..dbb143e6e168 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -522,15 +522,35 @@ class Sample(object):
   """Combiners for sampling n elements without replacement."""
   # pylint: disable=no-self-argument
 
-  @staticmethod
-  @ptransform.ptransform_fn
-  def FixedSizeGlobally(pcoll, n):
-    return pcoll | core.CombineGlobally(SampleCombineFn(n))
+  class FixedSizeGlobally(ptransform.PTransform):
+    """Sample n elements from the input PCollection without replacement."""
 
-  @staticmethod
-  @ptransform.ptransform_fn
-  def FixedSizePerKey(pcoll, n):
-    return pcoll | core.CombinePerKey(SampleCombineFn(n))
+    def __init__(self, n):
+      self._n = n
+
+    def expand(self, pcoll):
+      return pcoll | core.CombineGlobally(SampleCombineFn(self._n))
+
+    def display_data(self):
+      return {'n': self._n}
+
+    def default_label(self):
+      return 'FixedSizeGlobally(%d)' % self._n
+
+  class FixedSizePerKey(ptransform.PTransform):
+    """Sample n elements associated with each key without replacement."""
+
+    def __init__(self, n):
+      self._n = n
+
+    def expand(self, pcoll):
+      return pcoll | core.CombinePerKey(SampleCombineFn(self._n))
+
+    def display_data(self):
+      return {'n': self._n}
+
+    def default_label(self):
+      return 'FixedSizePerKey(%d)' % self._n
 
 
 @with_input_types(T)
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index 637a41f3dcb5..3db019a03599 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -177,26 +177,16 @@ def individual_test_per_key_dd(combineFn):
     individual_test_per_key_dd(combine.Largest(5))
 
   def test_combine_sample_display_data(self):
-    def individual_test_per_key_dd(sampleFn, args, kwargs):
-      trs = [sampleFn(*args, **kwargs)]
+    def individual_test_per_key_dd(sampleFn, n):
+      trs = [sampleFn(n)]
       for transform in trs:
         dd = DisplayData.create_from(transform)
-        expected_items = [
-            DisplayDataItemMatcher('fn', transform._fn.__name__)]
-        if args:
-          expected_items.append(
-              DisplayDataItemMatcher('args', str(args)))
-        if kwargs:
-          expected_items.append(
-              DisplayDataItemMatcher('kwargs', str(kwargs)))
-        hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
-
-    individual_test_per_key_dd(combine.Sample.FixedSizePerKey,
-                               args=(5,),
-                               kwargs={})
-    individual_test_per_key_dd(combine.Sample.FixedSizeGlobally,
-                               args=(8,),
-                               kwargs={'arg':  9})
+        hc.assert_that(
+            dd.items,
+            hc.contains_inanyorder(DisplayDataItemMatcher('n', transform._n)))
+
+    individual_test_per_key_dd(combine.Sample.FixedSizePerKey, 5)
+    individual_test_per_key_dd(combine.Sample.FixedSizeGlobally, 5)
 
   def test_combine_globally_display_data(self):
     transform = beam.CombineGlobally(combine.Smallest(5))


With regards,
Apache Git Services