You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2017/10/14 00:13:55 UTC
[2/2] beam git commit: Add an element batching transform.
Add an element batching transform.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/d226c767
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/d226c767
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/d226c767
Branch: refs/heads/master
Commit: d226c7679b9d94a40553609f31ecbfba72559e8a
Parents: 3dc7559
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Mon Oct 9 16:46:19 2017 -0700
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Fri Oct 13 17:13:41 2017 -0700
----------------------------------------------------------------------
sdks/python/apache_beam/transforms/util.py | 260 +++++++++++++++++++
sdks/python/apache_beam/transforms/util_test.py | 108 ++++++++
2 files changed, 368 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/d226c767/sdks/python/apache_beam/transforms/util.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 647781f..85d4975 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -20,14 +20,25 @@
from __future__ import absolute_import
+import collections
+import contextlib
+import time
+
+from apache_beam import typehints
+from apache_beam.metrics import Metrics
+from apache_beam.transforms import window
from apache_beam.transforms.core import CombinePerKey
+from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import Flatten
from apache_beam.transforms.core import GroupByKey
from apache_beam.transforms.core import Map
+from apache_beam.transforms.core import ParDo
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.ptransform import ptransform_fn
+from apache_beam.utils import windowed_value
__all__ = [
+ 'BatchElements',
'CoGroupByKey',
'Keys',
'KvSwap',
@@ -36,6 +47,9 @@ __all__ = [
]
+T = typehints.TypeVariable('T')
+
+
class CoGroupByKey(PTransform):
"""Groups results across several PCollections by key.
@@ -163,3 +177,249 @@ def RemoveDuplicates(pcoll): # pylint: disable=invalid-name
| 'ToPairs' >> Map(lambda v: (v, None))
| 'Group' >> CombinePerKey(lambda vs: None)
| 'RemoveDuplicates' >> Keys())
+
+
class _BatchSizeEstimator(object):
  """Estimates the best size for batches given historical timing.

  Collects (batch_size, elapsed_seconds) samples via record_time() and fits
  a linear model time = a + b * batch_size to them, then suggests the
  largest batch size that still satisfies the configured overhead and/or
  duration targets.
  """

  # Number of samples at which _thin_data() compacts the history.
  _MAX_DATA_POINTS = 100
  # Never suggest more than double the previously used batch size.
  _MAX_GROWTH_FACTOR = 2

  def __init__(self,
               min_batch_size=1,
               max_batch_size=1000,
               target_batch_overhead=.1,
               target_batch_duration_secs=1,
               clock=time.time):
    """
    Args:
      min_batch_size: (optional) lower bound on suggested batch sizes
      max_batch_size: (optional) upper bound on suggested batch sizes
      target_batch_overhead: (optional) target ratio of fixed cost to total
          batch time, in (0, 1]; a falsy value disables this target
      target_batch_duration_secs: (optional) target time per batch, in
          seconds; a falsy value disables this target
      clock: (optional) zero-argument callable returning the current time in
          seconds, like time.time (injectable for testing)

    Raises:
      ValueError: if the bounds are inverted, a target is out of range, or
          neither target is enabled.
    """
    if min_batch_size > max_batch_size:
      raise ValueError("Minimum (%s) must not be greater than maximum (%s)" % (
          min_batch_size, max_batch_size))
    if target_batch_overhead and not 0 < target_batch_overhead <= 1:
      raise ValueError("target_batch_overhead (%s) must be between 0 and 1" % (
          target_batch_overhead))
    if target_batch_duration_secs and target_batch_duration_secs <= 0:
      raise ValueError("target_batch_duration_secs (%s) must be positive" % (
          target_batch_duration_secs))
    # Written with `or` rather than max(0, ...) so that None-valued targets
    # do not raise a TypeError on Python 3 (None is not orderable there).
    if not (target_batch_overhead or target_batch_duration_secs):
      raise ValueError("At least one of target_batch_overhead or "
                       "target_batch_duration_secs must be positive.")
    self._min_batch_size = min_batch_size
    self._max_batch_size = max_batch_size
    self._target_batch_overhead = target_batch_overhead
    self._target_batch_duration_secs = target_batch_duration_secs
    self._clock = clock
    self._data = []
    self._ignore_next_timing = False
    self._size_distribution = Metrics.distribution(
        'BatchElements', 'batch_size')
    self._time_distribution = Metrics.distribution(
        'BatchElements', 'msec_per_batch')
    # Beam distributions only accept integer values, so we use this to
    # accumulate under-reported values until they add up to whole
    # milliseconds. (Milliseconds are chosen because that's conventionally
    # used elsewhere in profiling-style counters.)
    self._remainder_msecs = 0

  def ignore_next_timing(self):
    """Call to indicate the next timing should be ignored.

    For example, the first emit of a ParDo operation is known to be anomalous
    due to setup that may occur.
    """
    # BUG FIX: this previously assigned False, making the call a no-op and
    # letting the anomalous first timing pollute the regression data.
    self._ignore_next_timing = True

  @contextlib.contextmanager
  def record_time(self, batch_size):
    """Context manager that times its body and records a (size, time) sample.

    The elapsed time includes everything executed while the context is open,
    e.g. downstream processing of a yielded batch.
    """
    start = self._clock()
    yield
    elapsed = self._clock() - start
    # Carry fractional milliseconds over to the next batch so the integer
    # distribution does not systematically under-count.
    elapsed_msec = 1e3 * elapsed + self._remainder_msecs
    self._size_distribution.update(batch_size)
    self._time_distribution.update(int(elapsed_msec))
    self._remainder_msecs = elapsed_msec - int(elapsed_msec)
    if self._ignore_next_timing:
      self._ignore_next_timing = False
    else:
      self._data.append((batch_size, elapsed))
      if len(self._data) >= self._MAX_DATA_POINTS:
        self._thin_data()

  def _thin_data(self):
    """Roughly halve the sample set, preserving outliers.

    Pairs of adjacent (by batch size) samples that look alike are averaged
    into one sample; the most dissimilar pairs are kept as-is.
    """
    sorted_data = sorted(self._data)
    odd_one_out = [sorted_data[-1]] if len(sorted_data) % 2 == 1 else []
    # Sort the pairs by how different they are (ratio of batch sizes).
    # Written without Python 2's tuple-parameter unpacking (a syntax error
    # on Python 3); true division is used as the ratio is what matters.
    pairs = sorted(zip(sorted_data[::2], sorted_data[1::2]),
                   key=lambda pair: pair[1][0] / pair[0][0])
    # Keep the top 1/3 most different pairs, average the top 2/3 most similar.
    threshold = 2 * len(pairs) // 3
    self._data = (
        list(sum(pairs[threshold:], ()))
        + [((x1 + x2) / 2.0, (t1 + t2) / 2.0)
           for (x1, t1), (x2, t2) in pairs[:threshold]]
        + odd_one_out)

  def next_batch_size(self):
    """Returns the batch size to use next, based on the data seen so far."""
    if self._min_batch_size == self._max_batch_size:
      return self._min_batch_size
    elif len(self._data) < 1:
      return self._min_batch_size
    elif len(self._data) < 2:
      # Force some variety so we have distinct batch sizes on which to do
      # linear regression below.
      return int(max(
          min(self._max_batch_size,
              self._min_batch_size * self._MAX_GROWTH_FACTOR),
          self._min_batch_size + 1))

    # Linear regression for y = a + bx, where x is batch size and y is time.
    xs, ys = zip(*self._data)
    n = float(len(self._data))
    xbar = sum(xs) / n
    ybar = sum(ys) / n
    b = (sum((x - xbar) * (y - ybar) for x, y in self._data)
         / sum((x - xbar)**2 for x in xs))
    a = ybar - b * xbar

    # Avoid nonsensical or division-by-zero errors below due to noise.
    a = max(a, 1e-10)
    b = max(b, 1e-20)

    last_batch_size = self._data[-1][0]
    cap = min(last_batch_size * self._MAX_GROWTH_FACTOR, self._max_batch_size)

    if self._target_batch_duration_secs:
      # Solution to a + b*x = self._target_batch_duration_secs.
      cap = min(cap, (self._target_batch_duration_secs - a) / b)

    if self._target_batch_overhead:
      # Solution to a / (a + b*x) = self._target_batch_overhead.
      cap = min(cap, (a / b) * (1 / self._target_batch_overhead - 1))

    # Avoid getting stuck at min_batch_size.
    jitter = len(self._data) % 2
    return int(max(self._min_batch_size + jitter, cap))
+
+
class _GlobalWindowsBatchingDoFn(DoFn):
  """Buffers globally-windowed elements and emits them as adaptive batches."""

  def __init__(self, batch_size_estimator):
    self._estimator = batch_size_estimator

  def start_bundle(self):
    # Each bundle starts with an empty buffer and a fresh size estimate.
    self._buffer = []
    self._target_size = self._estimator.next_batch_size()
    # The first emit often involves non-trivial setup.
    self._estimator.ignore_next_timing()

  def process(self, element):
    self._buffer.append(element)
    if len(self._buffer) < self._target_size:
      return
    # Buffer is full: emit it, timing the (fused) downstream processing so
    # the estimator can adapt the next target size.
    with self._estimator.record_time(self._target_size):
      yield self._buffer
    self._buffer = []
    self._target_size = self._estimator.next_batch_size()

  def finish_bundle(self):
    if not self._buffer:
      return
    # Flush the partial batch; values yielded from finish_bundle must be
    # explicitly windowed, here into the global window.
    with self._estimator.record_time(self._target_size):
      yield window.GlobalWindows.windowed_value(self._buffer)
    self._buffer = None
    self._target_size = self._estimator.next_batch_size()
+
+
class _WindowAwareBatchingDoFn(DoFn):
  """Buffers elements per window, emitting batches into their own windows."""

  # Bound on the number of windows buffered at once; when exceeded, the
  # fullest window's batch is emitted early.
  _MAX_LIVE_WINDOWS = 10

  def __init__(self, batch_size_estimator):
    self._estimator = batch_size_estimator

  def start_bundle(self):
    self._buffers = collections.defaultdict(list)
    self._target_size = self._estimator.next_batch_size()
    # The first emit often involves non-trivial setup.
    self._estimator.ignore_next_timing()

  def process(self, element, window=DoFn.WindowParam):
    buf = self._buffers[window]
    buf.append(element)
    if len(buf) >= self._target_size:
      # This window's batch is full: emit it, timing the (fused) downstream
      # processing so the estimator can adapt.
      with self._estimator.record_time(self._target_size):
        yield windowed_value.WindowedValue(
            buf, window.max_timestamp(), (window,))
      del self._buffers[window]
      self._target_size = self._estimator.next_batch_size()
    elif len(self._buffers) > self._MAX_LIVE_WINDOWS:
      # Too many live windows: evict the one holding the most elements.
      fullest = max(self._buffers.items(), key=lambda kv: len(kv[1]))[0]
      with self._estimator.record_time(self._target_size):
        yield windowed_value.WindowedValue(
            self._buffers[fullest], fullest.max_timestamp(), (fullest,))
      del self._buffers[fullest]
      self._target_size = self._estimator.next_batch_size()

  def finish_bundle(self):
    # Flush every remaining partial batch into its own window.
    for win, batch in self._buffers.items():
      if batch:
        with self._estimator.record_time(self._target_size):
          yield windowed_value.WindowedValue(
              batch, win.max_timestamp(), (win,))
    self._buffers = None
    self._target_size = self._estimator.next_batch_size()
+
+
@typehints.with_input_types(T)
@typehints.with_output_types(typehints.List[T])
class BatchElements(PTransform):
  """A Transform that batches elements for amortized processing.

  This transform is designed to precede operations whose processing cost
  is of the form

      time = fixed_cost + num_elements * per_element_cost

  where the per element cost is (often significantly) smaller than the fixed
  cost and could be amortized over multiple elements. It consumes a
  PCollection of element type T and produces a PCollection of element type
  List[T].

  This transform attempts to find the best batch size between the minimum
  and maximum parameters by profiling the time taken by (fused) downstream
  operations. For a fixed batch size, set the min and max to be equal.

  Elements are batched per-window and batches emitted in the window
  corresponding to its contents.

  Args:
    min_batch_size: (optional) the smallest number of elements per batch
    max_batch_size: (optional) the largest number of elements per batch
    target_batch_overhead: (optional) a target for fixed_cost / time,
        as used in the formula above
    target_batch_duration_secs: (optional) a target for total time per bundle,
        in seconds
    clock: (optional) an alternative to time.time for measuring the cost of
        downstream operations (mostly for testing)
  """
  def __init__(self,
               min_batch_size=1,
               max_batch_size=1000,
               target_batch_overhead=.05,
               target_batch_duration_secs=1,
               clock=time.time):
    self._batch_size_estimator = _BatchSizeEstimator(
        min_batch_size=min_batch_size,
        max_batch_size=max_batch_size,
        target_batch_overhead=target_batch_overhead,
        target_batch_duration_secs=target_batch_duration_secs,
        clock=clock)

  def expand(self, pcoll):
    # Streaming would need per-key/window state to hold partial batches
    # across bundles; not yet supported.
    if getattr(pcoll.pipeline.runner, 'is_streaming', False):
      raise NotImplementedError("Requires stateful processing (BEAM-2687)")
    elif pcoll.windowing.is_default():
      # This is the same logic as _WindowAwareBatchingDoFn, but optimized
      # for that simpler case.
      return pcoll | ParDo(_GlobalWindowsBatchingDoFn(
          self._batch_size_estimator))
    else:
      return pcoll | ParDo(_WindowAwareBatchingDoFn(self._batch_size_estimator))
http://git-wip-us.apache.org/repos/asf/beam/blob/d226c767/sdks/python/apache_beam/transforms/util_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
new file mode 100644
index 0000000..6064e2c
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -0,0 +1,108 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for the transform.util classes."""
+
+import time
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.transforms import util
+from apache_beam.transforms import window
+
+
class FakeClock(object):
  """A deterministic stand-in for time.time.

  Calling the instance returns the current fake time; the fake time only
  advances when sleep() is called, which never actually blocks.
  """

  def __init__(self):
    # Seed from the real clock so values resemble genuine timestamps.
    self._current = time.time()

  def __call__(self):
    # Mimics time.time().
    return self._current

  def sleep(self, duration):
    # Advance the fake clock by `duration` seconds without blocking.
    self._current += duration
+
+
class BatchElementsTest(unittest.TestCase):
  """Tests for util.BatchElements and its _BatchSizeEstimator helper."""

  def test_constant_batch(self):
    """With min == max, batch size is fixed; the final batch may be short."""
    # Assumes a single bundle...
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(35))
          | util.BatchElements(min_batch_size=10, max_batch_size=10)
          | beam.Map(len))
      assert_that(res, equal_to([10, 10, 10, 5]))

  def test_grows_to_max_batch(self):
    """Batch size should double until it is capped at max_batch_size."""
    # Assumes a single bundle...
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(164))
          | util.BatchElements(
              min_batch_size=1, max_batch_size=50, clock=FakeClock())
          | beam.Map(len))
      assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))

  def test_windowed_batches(self):
    """Batches must not span window boundaries."""
    # Assumes a single bundle, in order...
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(47))
          | beam.Map(lambda t: window.TimestampedValue(t, t))
          | beam.WindowInto(window.FixedWindows(30))
          | util.BatchElements(
              min_batch_size=5, max_batch_size=10, clock=FakeClock())
          | beam.Map(len))
      assert_that(res, equal_to([
          5, 5, 10, 10, # elements in [0, 30)
          10, 7,        # elements in [30, 47)
          ]))

  def test_target_duration(self):
    """The estimator should converge on the duration target alone."""
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=None, target_batch_duration_secs=10, clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    # 1 + 12 * .7 == 9.4 seconds is as close to the 10-second target as we
    # can get without exceeding it (13 elements would take 10.1 seconds).
    expected_sizes = [1, 2, 4, 8, 12, 12, 12]
    actual_sizes = []
    for _ in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    self.assertEqual(expected_sizes, actual_sizes)

  def test_target_overhead(self):
    """The estimator should converge on the overhead target alone."""
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=.05, target_batch_duration_secs=None, clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    # At 27 items, a batch takes ~20 seconds with 5% (~1 second) overhead.
    expected_sizes = [1, 2, 4, 8, 16, 27, 27, 27]
    actual_sizes = []
    for _ in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    self.assertEqual(expected_sizes, actual_sizes)