You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2021/12/23 20:48:03 UTC
[beam] branch master updated: Better type inference for GroupBy. (#16318)
This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c1448f1 Better type inference for GroupBy. (#16318)
c1448f1 is described below
commit c1448f10b6f0115c8e32560acac010e1c798e4ff
Author: Robert Bradshaw <ro...@google.com>
AuthorDate: Thu Dec 23 12:46:54 2021 -0800
Better type inference for GroupBy. (#16318)
Observed a 2-3x performance improvement for simple tests.
Also allows Rows with Any components to be used as keys.
---
sdks/python/apache_beam/coders/row_coder.py | 17 +++++++++++++++--
sdks/python/apache_beam/transforms/core.py | 18 +++++++++++-------
sdks/python/apache_beam/transforms/ptransform_test.py | 12 ++++++++++++
3 files changed, 38 insertions(+), 9 deletions(-)
diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py
index 5219d9c..815dcd0 100644
--- a/sdks/python/apache_beam/coders/row_coder.py
+++ b/sdks/python/apache_beam/coders/row_coder.py
@@ -47,7 +47,7 @@ class RowCoder(FastCoder):
Implements the beam:coder:row:v1 standard coder spec.
"""
- def __init__(self, schema):
+ def __init__(self, schema, force_deterministic=False):
"""Initializes a :class:`RowCoder`.
Args:
@@ -64,6 +64,11 @@ class RowCoder(FastCoder):
self.components = [
_nonnull_coder_from_type(field.type) for field in self.schema.fields
]
+ if force_deterministic:
+ self.components = [
+ c.as_deterministic_coder(force_deterministic) for c in self.components
+ ]
+ self.forced_deterministic = bool(force_deterministic)
def _create_impl(self):
return RowCoderImpl(self.schema, self.components)
@@ -71,6 +76,12 @@ class RowCoder(FastCoder):
def is_deterministic(self):
return all(c.is_deterministic() for c in self.components)
+ def as_deterministic_coder(self, step_label, error_message=None):
+ if self.is_deterministic():
+ return self
+ else:
+ return RowCoder(self.schema, error_message or step_label)
+
def to_type_hint(self):
return self._type_hint
@@ -78,7 +89,9 @@ class RowCoder(FastCoder):
return hash(self.schema.SerializeToString())
def __eq__(self, other):
- return type(self) == type(other) and self.schema == other.schema
+ return (
+ type(self) == type(other) and self.schema == other.schema and
+ self.forced_deterministic == other.forced_deterministic)
def to_runner_api_parameter(self, unused_context):
return (common_urns.coders.ROW.urn, self.schema, [])
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index bd999f8..a6eeac3 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -54,6 +54,7 @@ from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.transforms.window import WindowedValue
from apache_beam.transforms.window import WindowFn
+from apache_beam.typehints import row_type
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import TypeCheckError
from apache_beam.typehints.decorators import WithTypeHints
@@ -2719,12 +2720,15 @@ class GroupBy(PTransform):
key_exprs = [expr for _, expr in self._key_fields]
return lambda element: key_type(*(expr(element) for expr in key_exprs))
- def _key_type_hint(self):
+ def _key_type_hint(self, input_type):
if not self._force_tuple_keys and len(self._key_fields) == 1:
- return typing.Any
+ expr = self._key_fields[0][1]
+ return trivial_inference.infer_return_type(expr, [input_type])
else:
- return _dynamic_named_tuple(
- 'Key', tuple(name for name, _ in self._key_fields))
+ return row_type.RowTypeConstraint([
+ (name, trivial_inference.infer_return_type(expr, [input_type]))
+ for (name, expr) in self._key_fields
+ ])
def default_label(self):
return 'GroupBy(%s)' % ', '.join(name for name, _ in self._key_fields)
@@ -2734,7 +2738,7 @@ class GroupBy(PTransform):
return (
pcoll
| Map(lambda x: (self._key_func()(x), x)).with_output_types(
- typehints.Tuple[self._key_type_hint(), input_type])
+ typehints.Tuple[self._key_type_hint(input_type), input_type])
| GroupByKey())
@@ -2785,7 +2789,8 @@ class _GroupAndAggregate(PTransform):
result_fields = tuple(name
for name, _ in self._grouping._key_fields) + tuple(
dest for _, __, dest in self._aggregations)
- key_type_hint = self._grouping.force_tuple_keys(True)._key_type_hint()
+ key_type_hint = self._grouping.force_tuple_keys(True)._key_type_hint(
+ pcoll.element_type)
return (
pcoll
@@ -2832,7 +2837,6 @@ class Select(PTransform):
for name, expr in self._fields}))
def infer_output_type(self, input_type):
- from apache_beam.typehints import row_type
return row_type.RowTypeConstraint([
(name, trivial_inference.infer_return_type(expr, [input_type]))
for (name, expr) in self._fields
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index 7b3e10a..191ba8a 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -1000,6 +1000,18 @@ class TestGroupBy(unittest.TestCase):
beam.Row(square=4, big=True, sum=2, positive=True), # [2]
]))
+ def test_pickled_field(self):
+ with TestPipeline() as p:
+ assert_that(
+ p
+ | beam.Create(['a', 'a', 'b'])
+ | beam.Map(
+ lambda s: beam.Row(
+ key1=PickledObject(s), key2=s.upper(), value=0))
+ | beam.GroupBy('key1', 'key2')
+ | beam.MapTuple(lambda k, vs: (k.key1.value, k.key2, len(list(vs)))),
+ equal_to([('a', 'A', 2), ('b', 'B', 1)]))
+
class SelectTest(unittest.TestCase):
def test_simple(self):