You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2021/12/23 20:48:03 UTC

[beam] branch master updated: Better type inference for GroupBy. (#16318)

This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c1448f1  Better type inference for GroupBy. (#16318)
c1448f1 is described below

commit c1448f10b6f0115c8e32560acac010e1c798e4ff
Author: Robert Bradshaw <ro...@google.com>
AuthorDate: Thu Dec 23 12:46:54 2021 -0800

    Better type inference for GroupBy. (#16318)
    
    Observed a 2-3x performance improvement for simple tests.
    
    Also allows Rows with Any components to be used as keys.
---
 sdks/python/apache_beam/coders/row_coder.py           | 17 +++++++++++++++--
 sdks/python/apache_beam/transforms/core.py            | 18 +++++++++++-------
 sdks/python/apache_beam/transforms/ptransform_test.py | 12 ++++++++++++
 3 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py
index 5219d9c..815dcd0 100644
--- a/sdks/python/apache_beam/coders/row_coder.py
+++ b/sdks/python/apache_beam/coders/row_coder.py
@@ -47,7 +47,7 @@ class RowCoder(FastCoder):
 
   Implements the beam:coder:row:v1 standard coder spec.
   """
-  def __init__(self, schema):
+  def __init__(self, schema, force_deterministic=False):
     """Initializes a :class:`RowCoder`.
 
     Args:
@@ -64,6 +64,11 @@ class RowCoder(FastCoder):
     self.components = [
         _nonnull_coder_from_type(field.type) for field in self.schema.fields
     ]
+    if force_deterministic:
+      self.components = [
+          c.as_deterministic_coder(force_deterministic) for c in self.components
+      ]
+    self.forced_deterministic = bool(force_deterministic)
 
   def _create_impl(self):
     return RowCoderImpl(self.schema, self.components)
@@ -71,6 +76,12 @@ class RowCoder(FastCoder):
   def is_deterministic(self):
     return all(c.is_deterministic() for c in self.components)
 
+  def as_deterministic_coder(self, step_label, error_message=None):
+    if self.is_deterministic():
+      return self
+    else:
+      return RowCoder(self.schema, error_message or step_label)
+
   def to_type_hint(self):
     return self._type_hint
 
@@ -78,7 +89,9 @@ class RowCoder(FastCoder):
     return hash(self.schema.SerializeToString())
 
   def __eq__(self, other):
-    return type(self) == type(other) and self.schema == other.schema
+    return (
+        type(self) == type(other) and self.schema == other.schema and
+        self.forced_deterministic == other.forced_deterministic)
 
   def to_runner_api_parameter(self, unused_context):
     return (common_urns.coders.ROW.urn, self.schema, [])
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index bd999f8..a6eeac3 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -54,6 +54,7 @@ from apache_beam.transforms.window import TimestampCombiner
 from apache_beam.transforms.window import TimestampedValue
 from apache_beam.transforms.window import WindowedValue
 from apache_beam.transforms.window import WindowFn
+from apache_beam.typehints import row_type
 from apache_beam.typehints import trivial_inference
 from apache_beam.typehints.decorators import TypeCheckError
 from apache_beam.typehints.decorators import WithTypeHints
@@ -2719,12 +2720,15 @@ class GroupBy(PTransform):
       key_exprs = [expr for _, expr in self._key_fields]
       return lambda element: key_type(*(expr(element) for expr in key_exprs))
 
-  def _key_type_hint(self):
+  def _key_type_hint(self, input_type):
     if not self._force_tuple_keys and len(self._key_fields) == 1:
-      return typing.Any
+      expr = self._key_fields[0][1]
+      return trivial_inference.infer_return_type(expr, [input_type])
     else:
-      return _dynamic_named_tuple(
-          'Key', tuple(name for name, _ in self._key_fields))
+      return row_type.RowTypeConstraint([
+          (name, trivial_inference.infer_return_type(expr, [input_type]))
+          for (name, expr) in self._key_fields
+      ])
 
   def default_label(self):
     return 'GroupBy(%s)' % ', '.join(name for name, _ in self._key_fields)
@@ -2734,7 +2738,7 @@ class GroupBy(PTransform):
     return (
         pcoll
         | Map(lambda x: (self._key_func()(x), x)).with_output_types(
-            typehints.Tuple[self._key_type_hint(), input_type])
+            typehints.Tuple[self._key_type_hint(input_type), input_type])
         | GroupByKey())
 
 
@@ -2785,7 +2789,8 @@ class _GroupAndAggregate(PTransform):
     result_fields = tuple(name
                           for name, _ in self._grouping._key_fields) + tuple(
                               dest for _, __, dest in self._aggregations)
-    key_type_hint = self._grouping.force_tuple_keys(True)._key_type_hint()
+    key_type_hint = self._grouping.force_tuple_keys(True)._key_type_hint(
+        pcoll.element_type)
 
     return (
         pcoll
@@ -2832,7 +2837,6 @@ class Select(PTransform):
                                 for name, expr in self._fields}))
 
   def infer_output_type(self, input_type):
-    from apache_beam.typehints import row_type
     return row_type.RowTypeConstraint([
         (name, trivial_inference.infer_return_type(expr, [input_type]))
         for (name, expr) in self._fields
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index 7b3e10a..191ba8a 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -1000,6 +1000,18 @@ class TestGroupBy(unittest.TestCase):
               beam.Row(square=4, big=True, sum=2, positive=True),     # [2]
           ]))
 
+  def test_pickled_field(self):
+    with TestPipeline() as p:
+      assert_that(
+          p
+          | beam.Create(['a', 'a', 'b'])
+          | beam.Map(
+              lambda s: beam.Row(
+                  key1=PickledObject(s), key2=s.upper(), value=0))
+          | beam.GroupBy('key1', 'key2')
+          | beam.MapTuple(lambda k, vs: (k.key1.value, k.key2, len(list(vs)))),
+          equal_to([('a', 'A', 2), ('b', 'B', 1)]))
+
 
 class SelectTest(unittest.TestCase):
   def test_simple(self):