You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2021/06/21 20:09:14 UTC

[GitHub] [beam] yifanmai commented on a change in pull request #15029: More easily substituted CoGroupByKey transform.

yifanmai commented on a change in pull request #15029:
URL: https://github.com/apache/beam/pull/15029#discussion_r655669107



##########
File path: sdks/python/apache_beam/transforms/util.py
##########
@@ -144,66 +144,76 @@ class CoGroupByKey(PTransform):
       (or if there's a chance there may be none), this argument is the only way
       to provide pipeline information, and should be considered mandatory.
   """
-  def __init__(self, **kwargs):
-    super(CoGroupByKey, self).__init__()
-    self.pipeline = kwargs.pop('pipeline', None)
-    if kwargs:
-      raise ValueError('Unexpected keyword arguments: %s' % list(kwargs.keys()))
+  def __init__(self, *, pipeline=None):
+    self.pipeline = pipeline
 
   def _extract_input_pvalues(self, pvalueish):
     try:
       # If this works, it's a dict.
       return pvalueish, tuple(pvalueish.values())
     except AttributeError:
+      # Cast iterables a tuple so we can do re-iteration.
       pcolls = tuple(pvalueish)
       return pcolls, pcolls
 
   def expand(self, pcolls):
-    """Performs CoGroupByKey on argument pcolls; see class docstring."""
-
-    # For associating values in K-V pairs with the PCollections they came from.
-    def _pair_tag_with_value(key_value, tag):
-      (key, value) = key_value
-      return (key, (tag, value))
-
-    # Creates the key, value pairs for the output PCollection. Values are either
-    # lists or dicts (per the class docstring), initialized by the result of
-    # result_ctor(result_ctor_arg).
-    def _merge_tagged_vals_under_key(key_grouped, result_ctor, result_ctor_arg):
-      (key, grouped) = key_grouped
-      result_value = result_ctor(result_ctor_arg)
-      for tag, value in grouped:
-        result_value[tag].append(value)
-      return (key, result_value)
+    if isinstance(pcolls, dict):
+      if all(isinstance(tag, str) and len(tag) < 10 for tag in pcolls.keys()):
+        # Small, string tags. Pass them as data.
+        pcolls_dict = pcolls
+        post_process = None
+      else:
+        # Pass the tags in the post_process closure.
+        tags = list(pcolls.keys())
+        pcolls_dict = {str(ix): pcolls[tag] for (ix, tag) in enumerate(tags)}
+        post_process = lambda vs: {
+            tag: vs[str(ix)]
+            for (ix, tag) in enumerate(tags)
+        }
+    else:
+      # Tags are tuple indices.
+      num_tags = len(pcolls)
+      pcolls_dict = {str(ix): pcolls[ix] for ix in range(num_tags)}
+      post_process = lambda vs: tuple(vs[str(ix)] for ix in range(num_tags))
+
+    result = pcolls_dict | _CoGBKImpl(pipeline=self.pipeline)
+    if post_process:
+      return result | MapTuple(lambda k, vs: (k, post_process(vs)))

Review comment:
       Add an informative stage name here? (e.g. "PostProcessTags" or "RestoreTags")

##########
File path: sdks/python/apache_beam/transforms/ptransform_test.py
##########
@@ -801,6 +801,32 @@ def test_co_group_by_key_on_dict(self):
               'X': [4], 'Y': [7, 8]
           })]))
 
+  def test_co_group_by_key_on_tuple_dict(self):

Review comment:
       optional: `test_co_group_by_key_on_dict_with_tuple_keys`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org