You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2017/05/02 00:45:03 UTC
[2/4] beam git commit: Fn API support for Python.
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger_test.py b/sdks/python/apache_beam/runners/worker/logger_test.py
new file mode 100644
index 0000000..cf3f692
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger_test.py
@@ -0,0 +1,182 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for worker logging utilities."""
+
+import json
+import logging
+import sys
+import threading
+import unittest
+
+from apache_beam.runners.worker import logger
+
+
class PerThreadLoggingContextTest(unittest.TestCase):
  """Tests for logger.PerThreadLoggingContext and per-thread worker data."""

  def _run_on_new_thread(self, target, *args):
    """Runs target(*args) on a new thread and re-raises its failure here.

    An exception raised on a non-test thread (including a failed assertion)
    does not fail the enclosing test by itself. Capturing it on the worker
    thread and re-raising it on the test thread makes such failures visible.
    """
    errors = []

    def _capturing_target():
      try:
        target(*args)
      except Exception as e:  # pylint: disable=broad-except
        errors.append(e)

    thread = threading.Thread(target=_capturing_target)
    thread.start()
    thread.join()
    if errors:
      raise errors[0]

  def thread_check_attribute(self, name):
    """Checks that `name` is unset here and only set inside a context."""
    self.assertNotIn(name, logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(**{name: 'thread-value'}):
      self.assertEqual(
          logger.per_thread_worker_data.get_data()[name], 'thread-value')
    self.assertNotIn(name, logger.per_thread_worker_data.get_data())

  def test_per_thread_attribute(self):
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(xyz='value'):
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
      # The other thread must not see this thread's 'xyz', and its own context
      # must not change the value seen here.
      self._run_on_new_thread(self.thread_check_attribute, 'xyz')
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())

  def test_set_when_undefined(self):
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(xyz='value'):
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())

  def test_set_when_already_defined(self):
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(xyz='value'):
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
      # A nested context overrides the value only for its own duration.
      with logger.PerThreadLoggingContext(xyz='value2'):
        self.assertEqual(
            logger.per_thread_worker_data.get_data()['xyz'], 'value2')
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())
+
+
class JsonLogFormatterTest(unittest.TestCase):
  """Tests for logger.JsonLogFormatter."""

  # Attributes of a fake log record, and the JSON expected from a formatter
  # built with job_id='jobid', worker_id='workerid'. Shared by every test:
  # treat both as read-only and copy them before changing any field.
  SAMPLE_RECORD = {
      'created': 123456.789, 'msecs': 789.654321,
      'msg': '%s:%d:%.2f', 'args': ('xyz', 4, 3.14),
      'levelname': 'WARNING',
      'process': 'pid', 'thread': 'tid',
      'name': 'name', 'filename': 'file', 'funcName': 'func',
      'exc_info': None}

  SAMPLE_OUTPUT = {
      'timestamp': {'seconds': 123456, 'nanos': 789654321},
      'severity': 'WARN', 'message': 'xyz:4:3.14', 'thread': 'pid:tid',
      'job': 'jobid', 'worker': 'workerid', 'logger': 'name:file:func'}

  def create_log_record(self, **kwargs):
    """Returns a minimal stand-in for a log record.

    The returned object has each keyword argument set as an attribute.
    """

    class Record(object):

      def __init__(self, **kwargs):
        # items() works on both Python 2 and 3 (iteritems() is 2-only).
        for k, v in kwargs.items():
          setattr(self, k, v)

    return Record(**kwargs)

  def test_basic_record(self):
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
    record = self.create_log_record(**self.SAMPLE_RECORD)
    self.assertEqual(json.loads(formatter.format(record)), self.SAMPLE_OUTPUT)

  def execute_multiple_cases(self, test_cases):
    """Checks the formatted 'message' for each {msg, args, expected} case.

    Every case is applied to fresh copies of SAMPLE_RECORD and SAMPLE_OUTPUT.
    Writing into the class-level dicts would leak state into other tests and
    make the suite order-dependent.
    """
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')

    for case in test_cases:
      record = dict(self.SAMPLE_RECORD, msg=case['msg'], args=case['args'])
      expected = dict(self.SAMPLE_OUTPUT, message=case['expected'])
      self.assertEqual(
          json.loads(formatter.format(self.create_log_record(**record))),
          expected)

  def test_record_with_format_character(self):
    # Note: ('xy') and (1) are a plain str and a plain int, not 1-tuples.
    test_cases = [
        {'msg': '%A', 'args': (), 'expected': '%A'},
        {'msg': '%s', 'args': (), 'expected': '%s'},
        {'msg': '%A%s', 'args': ('xy'), 'expected': '%A%s with args (xy)'},
        {'msg': '%s%s', 'args': (1), 'expected': '%s%s with args (1)'},
    ]

    self.execute_multiple_cases(test_cases)

  def test_record_with_arbitrary_messages(self):
    # Note: ('def') is a plain str, not a 1-tuple.
    test_cases = [
        {'msg': ImportError('abc'), 'args': (), 'expected': 'abc'},
        {'msg': TypeError('abc %s'), 'args': ('def'), 'expected': 'abc def'},
    ]

    self.execute_multiple_cases(test_cases)

  def test_record_with_per_thread_info(self):
    with logger.PerThreadLoggingContext(
        work_item_id='workitem', stage_name='stage', step_name='step'):
      formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
      record = self.create_log_record(**self.SAMPLE_RECORD)
      log_output = json.loads(formatter.format(record))
    expected_output = dict(self.SAMPLE_OUTPUT)
    expected_output.update(
        {'work': 'workitem', 'stage': 'stage', 'step': 'step'})
    self.assertEqual(log_output, expected_output)

  def test_nested_with_per_thread_info(self):
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
    with logger.PerThreadLoggingContext(
        work_item_id='workitem', stage_name='stage', step_name='step1'):
      record = self.create_log_record(**self.SAMPLE_RECORD)
      log_output1 = json.loads(formatter.format(record))

      with logger.PerThreadLoggingContext(step_name='step2'):
        record = self.create_log_record(**self.SAMPLE_RECORD)
        log_output2 = json.loads(formatter.format(record))

      record = self.create_log_record(**self.SAMPLE_RECORD)
      log_output3 = json.loads(formatter.format(record))

    record = self.create_log_record(**self.SAMPLE_RECORD)
    log_output4 = json.loads(formatter.format(record))

    self.assertEqual(log_output1, dict(
        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
    self.assertEqual(log_output2, dict(
        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2'))
    self.assertEqual(log_output3, dict(
        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
    self.assertEqual(log_output4, self.SAMPLE_OUTPUT)

  def test_exception_record(self):
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
    try:
      raise ValueError('Something')
    except ValueError:
      attribs = dict(self.SAMPLE_RECORD)
      attribs.update({'exc_info': sys.exc_info()})
      record = self.create_log_record(**attribs)
    log_output = json.loads(formatter.format(record))
    # Check if exception type, its message, and stack trace information are in.
    exn_output = log_output.pop('exception')
    self.assertIn('ValueError: Something', exn_output)
    self.assertIn('logger_test.py', exn_output)
    self.assertEqual(log_output, self.SAMPLE_OUTPUT)
+
+
if __name__ == '__main__':
  # Let INFO-level records through on the root logger while the tests run.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters.pxd b/sdks/python/apache_beam/runners/worker/opcounters.pxd
new file mode 100644
index 0000000..5c1079f
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters.pxd
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+cimport libc.stdint
+
+from apache_beam.utils.counters cimport Counter
+
+
cdef class SumAccumulator(object):
  # Cython declarations for the SumAccumulator class in opcounters.py.
  # The running total is held as a C int64, so update() takes and value()
  # returns an int64.
  cdef libc.stdint.int64_t _value
  cpdef update(self, libc.stdint.int64_t value)
  cpdef libc.stdint.int64_t value(self)
+
+
cdef class OperationCounters(object):
  # Cython declarations for the OperationCounters class in opcounters.py.
  # `cdef public` attributes stay readable and writable from Python code.
  cdef public _counter_factory
  cdef public Counter element_counter
  cdef public Counter mean_byte_counter
  cdef public coder_impl
  cdef public SumAccumulator active_accumulator
  # Sampling state read and updated by _should_sample().
  cdef public libc.stdint.int64_t _sample_counter
  cdef public libc.stdint.int64_t _next_sample

  cpdef update_from(self, windowed_value)
  cdef inline do_sample(self, windowed_value)
  cpdef update_collect(self)

  # The cdef methods are callable only from Cython code; should_sample() is
  # the cpdef wrapper that exposes the inline _should_sample() to Python.
  cdef libc.stdint.int64_t _compute_next_sample(self, libc.stdint.int64_t i)
  cdef inline bint _should_sample(self)
  cpdef bint should_sample(self)
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py
new file mode 100644
index 0000000..56ce0db
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters.py
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: profile=True
+
+"""Counters collect the progress of the Worker for reporting to the service."""
+
+from __future__ import absolute_import
+import math
+import random
+
+from apache_beam.utils.counters import Counter
+
+
class SumAccumulator(object):
  """Keeps a running sum; used to total up byte-size estimates."""

  def __init__(self):
    # The total starts at zero and only grows through update().
    self._value = 0

  def update(self, value):
    """Adds `value` to the running total."""
    self._value = self._value + value

  def value(self):
    """Returns the total accumulated so far."""
    return self._value
+
+
class OperationCounters(object):
  """The set of basic counters to attach to an Operation.

  Counts every element passed to update_from() and, for a sample of those
  elements, records an estimate of their encoded size in mean_byte_counter.
  """

  def __init__(self, counter_factory, step_name, coder, output_index):
    self._counter_factory = counter_factory
    self.element_counter = counter_factory.get_counter(
        '%s-out%d-ElementCount' % (step_name, output_index), Counter.SUM)
    self.mean_byte_counter = counter_factory.get_counter(
        '%s-out%d-MeanByteCount' % (step_name, output_index), Counter.MEAN)
    self.coder_impl = coder.get_impl()
    # Set by do_sample() when a sampled element's size involves observables;
    # folded into mean_byte_counter and cleared by update_collect().
    self.active_accumulator = None
    # Number of elements _should_sample() has been asked about so far.
    self._sample_counter = 0
    # Element index at which to sample next; 0 until gap estimation starts.
    self._next_sample = 0

  def update_from(self, windowed_value):
    """Add one value to this counter."""
    self.element_counter.update(1)
    if self._should_sample():
      self.do_sample(windowed_value)

  def _observable_callback(self, inner_coder_impl, accumulator):
    """Returns an observer that adds each observed value's size to accumulator.

    Values flagged as already encoded contribute len(value); all others are
    sized with inner_coder_impl.estimate_size().
    """
    def _observable_callback_inner(value, is_encoded=False):
      # TODO(ccy): If this stream is large, sample it as well.
      # To do this, we'll need to compute the average size of elements
      # in this stream to add the *total* size of this stream to accumulator.
      # We'll also want to make sure we sample at least some of this stream
      # (as self.should_sample() may be sampling very sparsely by now).
      if is_encoded:
        size = len(value)
        accumulator.update(size)
      else:
        accumulator.update(inner_coder_impl.estimate_size(value))
    return _observable_callback_inner

  def do_sample(self, windowed_value):
    """Records the estimated size of windowed_value.

    Without observables the size goes straight into mean_byte_counter.
    Otherwise a fresh accumulator is seeded with the size so far, observers
    are registered to add to it, and update_collect() records the total.
    """
    size, observables = (
        self.coder_impl.get_estimated_size_and_observables(windowed_value))
    if not observables:
      self.mean_byte_counter.update(size)
    else:
      self.active_accumulator = SumAccumulator()
      self.active_accumulator.update(size)
      for observable, inner_coder_impl in observables:
        observable.register_observer(
            self._observable_callback(
                inner_coder_impl, self.active_accumulator))

  def update_collect(self):
    """Collects the accumulated size estimates.

    Now that the element has been processed, we ask our accumulator
    for the total and store the result in a counter.
    """
    if self.active_accumulator is not None:
      self.mean_byte_counter.update(self.active_accumulator.value())
      self.active_accumulator = None

  def _compute_next_sample(self, i):
    """Returns the index of the next element to sample after element i.

    Draws a random gap so that roughly 10/i of upcoming elements get sampled.
    _should_sample() only calls this once i > 30, so 10.0 / i < 1 and the
    logarithm below is defined.
    """
    # https://en.wikipedia.org/wiki/Reservoir_sampling#Fast_Approximation
    gap = math.log(1.0 - random.random()) / math.log(1.0 - 10.0/i)
    return i + math.floor(gap)

  def _should_sample(self):
    """Determines whether to sample the next element.

    Size calculation can be expensive, so we don't do it for each element.
    Because we need only an estimate of average size, we sample.

    We always sample the first 10 elements, then the sampling rate
    is approximately 10/N. After reading N elements, of the next N,
    we will sample approximately 10*ln(2) (about 7) elements.

    This algorithm samples at the same rate as Reservoir Sampling, but
    it never throws away early results. (Because we keep only a
    running accumulation, storage is not a problem, so there is no
    need to discard earlier calculations.)

    Because we accumulate and do not replace, our statistics are
    biased toward early data. If the data are distributed uniformly,
    this is not a problem. If the data change over time (i.e., the
    element size tends to grow or shrink over time), our estimate will
    show the bias. We could correct this by giving weight N to each
    sample, since each sample is a stand-in for the N/(10*ln(2))
    samples around it, which is proportional to N. Since we do not
    expect biased data, for efficiency we omit the extra multiplication.
    We could reduce the early-data bias by putting a lower bound on
    the sampling rate.

    Computing random.randint(1, self._sample_counter) for each element
    is too slow, so when the sample size is big enough (we estimate 30
    is big enough), we estimate the size of the gap after each sample.
    This estimation allows us to call random much less often.

    Returns:
      True if it is time to compute another element's size.
    """

    self._sample_counter += 1
    if self._next_sample == 0:
      if random.randint(1, self._sample_counter) <= 10:
        if self._sample_counter > 30:
          self._next_sample = self._compute_next_sample(self._sample_counter)
        return True
      return False
    elif self._sample_counter >= self._next_sample:
      self._next_sample = self._compute_next_sample(self._sample_counter)
      return True
    return False

  def should_sample(self):
    # We create this separate method because the above "_should_sample()" method
    # is marked as inline in Cython and thus can't be exposed to Python code.
    return self._should_sample()

  def __str__(self):
    # This class defines no __iter__, so list the counters explicitly.
    counters = [self.element_counter, self.mean_byte_counter]
    return '<%s [%s]>' % (self.__class__.__name__,
                          ', '.join([str(x) for x in counters]))

  def __repr__(self):
    counters = [self.element_counter, self.mean_byte_counter]
    return '<%s %s at %s>' % (self.__class__.__name__,
                              counters, hex(id(self)))
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters_test.py b/sdks/python/apache_beam/runners/worker/opcounters_test.py
new file mode 100644
index 0000000..74561b8
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters_test.py
@@ -0,0 +1,149 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import math
+import random
+import unittest
+
+from apache_beam import coders
+from apache_beam.runners.worker.opcounters import OperationCounters
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.utils.counters import CounterFactory
+
+
+# Classes to test that we can handle a variety of objects.
+# These have to be at top level so the pickler can find them.
+
+
class OldClassThatDoesNotImplementLen:  # pylint: disable=old-style-class
  """Class without an explicit base (old-style on Python 2) and no __len__."""

  def __init__(self):
    """Takes no arguments and stores no state."""
+
+
class ObjectThatDoesNotImplementLen(object):
  """New-style class that deliberately defines no __len__."""

  def __init__(self):
    """Takes no arguments and stores no state."""
+
+
class OperationCountersTest(unittest.TestCase):
  """Tests for OperationCounters element counting and size sampling."""

  def verify_counters(self, opcounts, expected_elements, expected_size=None):
    """Asserts the element count and, if given, the mean byte size.

    Args:
      opcounts: the OperationCounters under test.
      expected_elements: expected value of the element counter.
      expected_size: expected mean byte size. None skips the check; NaN
        asserts the mean is NaN (NaN never compares equal to itself).
    """
    self.assertEqual(expected_elements, opcounts.element_counter.value())
    if expected_size is not None:
      if math.isnan(expected_size):
        self.assertTrue(math.isnan(opcounts.mean_byte_counter.value()))
      else:
        self.assertEqual(expected_size, opcounts.mean_byte_counter.value())

  def test_update_int(self):
    opcounts = OperationCounters(CounterFactory(), 'some-name',
                                 coders.PickleCoder(), 0)
    self.verify_counters(opcounts, 0)
    opcounts.update_from(GlobalWindows.windowed_value(1))
    self.verify_counters(opcounts, 1)

  def test_update_str(self):
    coder = coders.PickleCoder()
    opcounts = OperationCounters(CounterFactory(), 'some-name',
                                 coder, 0)
    self.verify_counters(opcounts, 0, float('nan'))
    value = GlobalWindows.windowed_value('abcde')
    opcounts.update_from(value)
    estimated_size = coder.estimate_size(value)
    self.verify_counters(opcounts, 1, estimated_size)

  def test_update_old_object(self):
    coder = coders.PickleCoder()
    opcounts = OperationCounters(CounterFactory(), 'some-name',
                                 coder, 0)
    self.verify_counters(opcounts, 0, float('nan'))
    obj = OldClassThatDoesNotImplementLen()
    value = GlobalWindows.windowed_value(obj)
    opcounts.update_from(value)
    estimated_size = coder.estimate_size(value)
    self.verify_counters(opcounts, 1, estimated_size)

  def test_update_new_object(self):
    coder = coders.PickleCoder()
    opcounts = OperationCounters(CounterFactory(), 'some-name',
                                 coder, 0)
    self.verify_counters(opcounts, 0, float('nan'))

    obj = ObjectThatDoesNotImplementLen()
    value = GlobalWindows.windowed_value(obj)
    opcounts.update_from(value)
    estimated_size = coder.estimate_size(value)
    self.verify_counters(opcounts, 1, estimated_size)

  def test_update_multiple(self):
    coder = coders.PickleCoder()
    total_size = 0
    opcounts = OperationCounters(CounterFactory(), 'some-name',
                                 coder, 0)
    self.verify_counters(opcounts, 0, float('nan'))
    value = GlobalWindows.windowed_value('abcde')
    opcounts.update_from(value)
    total_size += coder.estimate_size(value)
    value = GlobalWindows.windowed_value('defghij')
    opcounts.update_from(value)
    total_size += coder.estimate_size(value)
    self.verify_counters(opcounts, 2, float(total_size) / 2)
    value = GlobalWindows.windowed_value('klmnop')
    opcounts.update_from(value)
    total_size += coder.estimate_size(value)
    self.verify_counters(opcounts, 3, float(total_size) / 3)

  def test_should_sample(self):
    # Order of magnitude more buckets than highest constant in code under test.
    buckets = [0] * 300
    # The seed is arbitrary and exists just to ensure this test is robust.
    # If you don't like this seed, try your own; the test should still pass.
    random.seed(1717)
    # Do enough runs that the expected hits even in the last buckets
    # is big enough to expect some statistical smoothing.
    total_runs = 10 * len(buckets)

    # Fill the buckets. range() works on Python 2 and 3 (xrange is 2-only).
    for _ in range(total_runs):
      opcounts = OperationCounters(CounterFactory(), 'some-name',
                                   coders.PickleCoder(), 0)
      for i in range(len(buckets)):
        if opcounts.should_sample():
          buckets[i] += 1

    # The first 10 elements are always sampled.
    for i in range(10):
      self.assertEqual(total_runs, buckets[i])
    # After that, element i is sampled roughly 10/i of the time; accept a
    # band of 0.7x to 1.4x of that expectation.
    for i in range(10, len(buckets)):
      msg = 'i=%d, buckets[i]=%d, expected=%d, ratio=%f' % (
          i, buckets[i],
          10 * total_runs / i,
          buckets[i] / (10.0 * total_runs / i))
      self.assertGreater(buckets[i], 7 * total_runs / i, msg)
      self.assertLess(buckets[i], 14 * total_runs / i, msg)
+
+
if __name__ == '__main__':
  # Let INFO-level records through on the root logger while the tests run.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operation_specs.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operation_specs.py b/sdks/python/apache_beam/runners/worker/operation_specs.py
new file mode 100644
index 0000000..977e165
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operation_specs.py
@@ -0,0 +1,368 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Worker utilities for representing MapTasks.
+
+Each MapTask represents a sequence of ParallelInstruction(s): read from a
+source, write to a sink, parallel do, etc.
+"""
+
+import collections
+
+from apache_beam import coders
+
+
def build_worker_instruction(*args):
  """Creates a class representing a ParallelInstruction protobuf.

  The class is a collections.namedtuple whose __str__ and __repr__ are both
  set to worker_object_to_string, giving a compact one-line rendering.

  Pylint does not understand this wrapper and assumes it creates constants,
  so call sites may need a disable=invalid-name annotation, depending on
  their names.

  Args:
    *args: forwarded to collections.namedtuple. The first argument is the
      name of the type to create and should start with "Worker"; the second
      is a list of the attributes of this object.

  Returns:
    A new class, a subclass of tuple, that represents the protobuf.
  """
  worker_class = collections.namedtuple(*args)
  worker_class.__str__ = worker_class.__repr__ = worker_object_to_string
  return worker_class
+
+
def worker_printable_fields(workerproto):
  """Returns the interesting fields of a Worker* object.

  Args:
    workerproto: an instance of a namedtuple class, such as one created by
      build_worker_instruction.

  Returns:
    A list of 'name=value' strings in field order. It covers every field
    whose value is truthy or equal to 0 and whose name is not in the
    excluded set below.
  """
  # Coders, serialized functions, configs and similar fields, presumably
  # left out because they are bulky or opaque in a one-line summary.
  excluded = frozenset((
      'coder', 'coders', 'output_coders',
      'elements',
      'combine_fn', 'serialized_fn', 'window_fn',
      'append_trailing_newlines', 'strip_trailing_newlines',
      'compression_type', 'context',
      'start_shuffle_position', 'end_shuffle_position',
      'shuffle_reader_config', 'shuffle_writer_config'))
  return ['%s=%s' % (name, value)
          # _asdict is the only way and cannot subclass this generated class.
          # items() works on Python 2 and 3 (iteritems() is 2-only).
          # pylint: disable=protected-access
          for name, value in workerproto._asdict().items()
          # want to output value 0 but not None nor []
          if (value or value == 0) and name not in excluded]
+
+
def worker_object_to_string(worker_object):
  """Renders a Worker* object compactly as 'ClassName(field=value, ...)'."""
  class_name = worker_object.__class__.__name__
  printable = ', '.join(worker_printable_fields(worker_object))
  return '%s(%s)' % (class_name, printable)
+
+
+# All the following Worker* definitions will have these lint problems:
+# pylint: disable=invalid-name
+# pylint: disable=pointless-string-statement
+
+
WorkerRead = build_worker_instruction(
    'WorkerRead',
    ('source', 'output_coders'))
"""Worker details needed to read from a source.

Attributes:
  source: a source object.
  output_coders: 1-tuple of the coder for the output.
"""
+
+
WorkerSideInputSource = build_worker_instruction(
    'WorkerSideInputSource',
    ('source', 'tag'))
"""Worker details needed to read from a side input source.

Attributes:
  source: a source object.
  tag: string tag for this side input.
"""
+
+
WorkerGroupingShuffleRead = build_worker_instruction(
    'WorkerGroupingShuffleRead',
    ('start_shuffle_position', 'end_shuffle_position',
     'shuffle_reader_config', 'coder', 'output_coders'))
"""Worker details needed to read from a grouping shuffle source.

Attributes:
  start_shuffle_position: An opaque string passed to the shuffle source to
    indicate where to start reading.
  end_shuffle_position: An opaque string passed to the shuffle source to
    indicate where to stop reading.
  shuffle_reader_config: An opaque string used to initialize the shuffle
    reader. Contains things like connection endpoints for the shuffle
    server appliance and various options.
  coder: The KV coder used to decode shuffle entries.
  output_coders: 1-tuple of the coder for the output.
"""
+
+
WorkerUngroupedShuffleRead = build_worker_instruction(
    'WorkerUngroupedShuffleRead',
    ('start_shuffle_position', 'end_shuffle_position',
     'shuffle_reader_config', 'coder', 'output_coders'))
"""Worker details needed to read from an ungrouped shuffle source.

Attributes:
  start_shuffle_position: An opaque string passed to the shuffle source to
    indicate where to start reading.
  end_shuffle_position: An opaque string passed to the shuffle source to
    indicate where to stop reading.
  shuffle_reader_config: An opaque string used to initialize the shuffle
    reader. Contains things like connection endpoints for the shuffle
    server appliance and various options.
  coder: The value coder used to decode shuffle entries.
  output_coders: presumably, as for WorkerGroupingShuffleRead, a 1-tuple of
    the coder for the output.
"""
+
+
WorkerWrite = build_worker_instruction(
    'WorkerWrite',
    ('sink', 'input', 'output_coders'))
"""Worker details needed to write to a sink.

Attributes:
  sink: a sink object.
  input: A (producer index, output index) tuple identifying the
    ParallelInstruction operation whose output feeds into this operation.
    The output index is 0 except for multi-output operations (like ParDo).
  output_coders: 1-tuple, coder to use to estimate bytes written.
"""
+
+
WorkerInMemoryWrite = build_worker_instruction(
    'WorkerInMemoryWrite',
    ('output_buffer', 'write_windowed_values', 'input', 'output_coders'))
"""Worker details needed to write to an in-memory sink.

Used only for unit testing. It makes worker tests less cluttered with code like
"write to a file and then check file contents".

Attributes:
  output_buffer: list to which output elements will be appended.
  write_windowed_values: whether to record the entire WindowedValue outputs,
    or just the raw (unwindowed) value.
  input: A (producer index, output index) tuple identifying the
    ParallelInstruction operation whose output feeds into this operation.
    The output index is 0 except for multi-output operations (like ParDo).
  output_coders: 1-tuple, coder to use to estimate bytes written.
"""
+
+
WorkerShuffleWrite = build_worker_instruction(
    'WorkerShuffleWrite',
    ('shuffle_kind', 'shuffle_writer_config', 'input', 'output_coders'))
"""Worker details needed to write to a shuffle sink.

Attributes:
  shuffle_kind: A string describing the shuffle kind. This can control the
    way the worker interacts with the shuffle sink. The possible values are:
    'ungrouped', 'group_keys', and 'group_keys_and_sort_values'.
  shuffle_writer_config: An opaque string used to initialize the shuffle
    writer. Contains things like connection endpoints for the shuffle
    server appliance and various options.
  input: A (producer index, output index) tuple identifying the
    ParallelInstruction operation whose output feeds into this operation.
    The output index is 0 except for multi-output operations (like ParDo).
  output_coders: 1-tuple of the coder for input elements. If the
    shuffle_kind is grouping, this is expected to be a KV coder.
"""
+
+
WorkerDoFn = build_worker_instruction(
    'WorkerDoFn',
    ('serialized_fn', 'output_tags', 'input', 'side_inputs', 'output_coders'))
"""Worker details needed to run a DoFn.

Attributes:
  serialized_fn: A serialized DoFn object to be run for each input element.
  output_tags: The string tags used to identify the outputs of a ParDo
    operation. The tag is present even if the ParDo has just one output
    (e.g., ['out']).
  output_coders: array of coders, one for each output.
  input: A (producer index, output index) tuple identifying the
    ParallelInstruction operation whose output feeds into this operation.
    The output index is 0 except for multi-output operations (like ParDo).
  side_inputs: A list of Worker...Read instances describing sources to be
    used for getting values. The types supported right now are
    WorkerInMemoryRead and WorkerTextRead.
"""
+
+
WorkerReifyTimestampAndWindows = build_worker_instruction(
    'WorkerReifyTimestampAndWindows',
    ('output_tags', 'input', 'output_coders'))
"""Worker details needed to run a WindowInto.

Attributes:
  output_tags: The string tags used to identify the outputs of a ParDo
    operation. The tag is present even if the ParDo has just one output
    (e.g., ['out']).
  output_coders: array of coders, one for each output.
  input: A (producer index, output index) tuple identifying the
    ParallelInstruction operation whose output feeds into this operation.
    The output index is 0 except for multi-output operations (like ParDo).
"""
+
+
WorkerMergeWindows = build_worker_instruction(
    'WorkerMergeWindows',
    ('window_fn', 'combine_fn', 'phase', 'output_tags', 'input', 'coders',
     'context', 'output_coders'))
"""Worker details needed to run a MergeWindows (aka. GroupAlsoByWindows).

Attributes:
  window_fn: A serialized Windowing object representing the windowing strategy.
  combine_fn: A serialized CombineFn object to be used after executing the
    GroupAlsoByWindows operation. May be None if not a combining operation.
  phase: Possible values are 'all', 'add', 'merge', and 'extract'.
    A runner optimizer may split the user combiner in 3 separate
    phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees
    fit. The phase attribute dictates which DoFn is actually running in
    the worker. May be None if not a combining operation.
  output_tags: The string tags used to identify the outputs of a ParDo
    operation. The tag is present even if the ParDo has just one output
    (e.g., ['out']).
  output_coders: array of coders, one for each output.
  input: A (producer index, output index) tuple identifying the
    ParallelInstruction operation whose output feeds into this operation.
    The output index is 0 except for multi-output operations (like ParDo).
  coders: A 2-tuple of coders (key, value) to encode shuffle entries.
  context: The ExecutionContext object for the current work item.
"""
+
+
# Spec record consumed by CombineOperation (see create_operation); its
# serialized_fn is unpickled there and its first three items are used.
WorkerCombineFn = build_worker_instruction(
    'WorkerCombineFn',
    ['serialized_fn', 'phase', 'input', 'output_coders'])
"""Worker details needed to run a CombineFn.
Attributes:
  serialized_fn: A serialized CombineFn object to be used.
  phase: Possible values are 'all', 'add', 'merge', and 'extract'.
    A runner optimizer may split the user combiner in 3 separate
    phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees
    fit. The phase attribute dictates which DoFn is actually running in
    the worker.
  input: A (producer index, output index) tuple representing the
    ParallelInstruction operation whose output feeds into this operation.
    The output index is 0 except for multi-output operations (like ParDo).
  output_coders: 1-tuple of the coder for the output.
"""
+
+
# Spec record consumed by create_pgbk_op: a truthy combine_fn selects
# PGBKCVOperation (combining), otherwise PGBKOperation (plain grouping).
WorkerPartialGroupByKey = build_worker_instruction(
    'WorkerPartialGroupByKey',
    ['combine_fn', 'input', 'output_coders'])
"""Worker details needed to run a partial group-by-key.
Attributes:
  combine_fn: A serialized CombineFn object to be used.
  input: A (producer index, output index) tuple representing the
    ParallelInstruction operation whose output feeds into this operation.
    The output index is 0 except for multi-output operations (like ParDo).
  output_coders: 1-tuple of the coder for the output.
"""
+
+
# Spec record consumed by FlattenOperation. Unlike other specs it has
# 'inputs' (plural); SimpleMapTaskExecutor wires each entry to a producer.
WorkerFlatten = build_worker_instruction(
    'WorkerFlatten',
    ['inputs', 'output_coders'])
"""Worker details needed to run a Flatten.
Attributes:
  inputs: A list of tuples, each (producer index, output index), representing
    the ParallelInstruction operations whose output feeds into this operation.
    The output index is 0 unless the input is from a multi-output
    operation (such as ParDo).
  output_coders: 1-tuple of the coder for the output.
"""
+
+
def get_coder_from_spec(coder_spec):
  """Return a coder instance from a coder spec.

  Args:
    coder_spec: A dict describing a coder. Its '@type' value is either a
      wrapper type listed in ignored_wrappers (unwrapped to its single
      component), one of the well-known 'kind:*' names handled below (whose
      sub-coders are given in 'component_encodings'), or a serialized coder
      string of the form "<coder_name>$<pickled_data>".

  Returns:
    A coder instance (has encode/decode methods).
  """
  assert coder_spec is not None

  # Ignore the wrappers in these encodings.
  # TODO(silviuc): Make sure with all the renamings that names below are ok.
  if coder_spec['@type'] in ignored_wrappers:
    assert len(coder_spec['component_encodings']) == 1
    coder_spec = coder_spec['component_encodings'][0]
    return get_coder_from_spec(coder_spec)

  # Handle a few well known types of coders.
  if coder_spec['@type'] == 'kind:pair':
    assert len(coder_spec['component_encodings']) == 2
    component_coders = [
        get_coder_from_spec(c) for c in coder_spec['component_encodings']]
    return coders.TupleCoder(component_coders)
  elif coder_spec['@type'] == 'kind:stream':
    assert len(coder_spec['component_encodings']) == 1
    return coders.IterableCoder(
        get_coder_from_spec(coder_spec['component_encodings'][0]))
  elif coder_spec['@type'] == 'kind:windowed_value':
    assert len(coder_spec['component_encodings']) == 2
    value_coder, window_coder = [
        get_coder_from_spec(c) for c in coder_spec['component_encodings']]
    return coders.WindowedValueCoder(value_coder, window_coder=window_coder)
  elif coder_spec['@type'] == 'kind:interval_window':
    # Window coders take no components. (The previous check applied len() to
    # the result of '== 0', raising TypeError whenever the key was present.)
    assert ('component_encodings' not in coder_spec
            or not coder_spec['component_encodings'])
    return coders.IntervalWindowCoder()
  elif coder_spec['@type'] == 'kind:global_window':
    assert ('component_encodings' not in coder_spec
            or not coder_spec['component_encodings'])
    return coders.GlobalWindowCoder()
  elif coder_spec['@type'] == 'kind:length_prefix':
    assert len(coder_spec['component_encodings']) == 1
    return coders.LengthPrefixCoder(
        get_coder_from_spec(coder_spec['component_encodings'][0]))

  # We pass coders in the form "<coder_name>$<pickled_data>" to make the job
  # description JSON more readable.
  return coders.deserialize_coder(coder_spec['@type'])
+
+
class MapTask(object):
  """A decoded map task: its parsed operations plus per-step naming data.

  Attributes:
    operations: List of Worker* objects obtained by parsing the instructions
      of the map task.
    stage_name: Name of the execution stage this map task represents.
    system_names: For each operation, the system name of its step in the
      execution graph.
    step_names: For each operation, the name of its step.
    original_names: For each operation, the internal name of its step in the
      original workflow graph.
  """

  def __init__(
      self, operations, stage_name, system_names, step_names, original_names):
    (self.operations, self.stage_name, self.system_names,
     self.step_names, self.original_names) = (
         operations, stage_name, system_names, step_names, original_names)

  def __str__(self):
    joined_steps = '+'.join(self.step_names)
    return '<%s %s steps=%s>' % (
        type(self).__name__, self.stage_name, joined_steps)
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operations.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
new file mode 100644
index 0000000..2b4e526
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+
+from apache_beam.runners.common cimport Receiver
+from apache_beam.runners.worker cimport opcounters
+from apache_beam.utils.windowed_value cimport WindowedValue
+from apache_beam.metrics.execution cimport ScopedMetricsContainer
+
+
# C-level declarations for the module globals defined in operations.py.
cdef WindowedValue _globally_windowed_value
cdef type _global_window_type

# Edge from one operation output to its consumer operations (operations.py).
cdef class ConsumerSet(Receiver):
  cdef list consumers
  cdef opcounters.OperationCounters opcounter
  # Exposed to Python; used by ConsumerSet.__repr__.
  cdef public step_name
  cdef public output_index
  cdef public coder

  cpdef receive(self, WindowedValue windowed_value)
  cpdef update_counters_start(self, WindowedValue windowed_value)
  cpdef update_counters_finish(self)


# Base class for all worker operations; attributes are set in
# Operation.__init__/start() and by create_operation() in operations.py.
cdef class Operation(object):
  cdef readonly operation_name
  cdef readonly spec
  cdef object consumers
  cdef readonly counter_factory
  cdef public metrics_container
  cdef public ScopedMetricsContainer scoped_metrics_container
  # Public for access by Fn harness operations.
  # TODO(robertwb): Cythonize FnHarness.
  cdef public list receivers
  cdef readonly bint debug_logging_enabled

  cdef public step_name  # initialized lazily

  cdef readonly object state_sampler

  # Sampler scopes entered around start/process/finish by subclasses.
  cdef readonly object scoped_start_state
  cdef readonly object scoped_process_state
  cdef readonly object scoped_finish_state

  cpdef start(self)
  cpdef process(self, WindowedValue windowed_value)
  cpdef finish(self)

  # output_index defaults to 0 in the .py definition.
  cpdef output(self, WindowedValue windowed_value, int output_index=*)

cdef class ReadOperation(Operation):
  # Types the loop variable used in ReadOperation.start().
  @cython.locals(windowed_value=WindowedValue)
  cpdef start(self)

cdef class DoOperation(Operation):
  cdef object dofn_runner
  cdef Receiver dofn_receiver

cdef class CombineOperation(Operation):
  cdef object phased_combine_fn

cdef class FlattenOperation(Operation):
  pass

cdef class PGBKCVOperation(Operation):
  cdef public object combine_fn
  cdef public object combine_fn_add_input
  # (windows-or-0, key) -> one-element list holding the accumulator.
  cdef dict table
  cdef long max_keys
  cdef long key_count

  cpdef output_key(self, tuple wkey, value)
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operations.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
new file mode 100644
index 0000000..5dbe57e
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -0,0 +1,651 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: profile=True
+
+"""Worker operations executor."""
+
+import collections
+import itertools
+import logging
+
+from apache_beam import pvalue
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.metrics.execution import MetricsContainer
+from apache_beam.metrics.execution import ScopedMetricsContainer
+from apache_beam.runners import common
+from apache_beam.runners.common import Receiver
+from apache_beam.runners.dataflow.internal.names import PropertyNames
+from apache_beam.runners.worker import logger
+from apache_beam.runners.worker import opcounters
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import sideinputs
+from apache_beam.transforms import combiners
+from apache_beam.transforms import core
+from apache_beam.transforms import sideinputs as apache_sideinputs
+from apache_beam.transforms.combiners import curry_combine_fn
+from apache_beam.transforms.combiners import PhasedCombineFnExecutor
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.utils.windowed_value import WindowedValue
+
+# Allow some "pure mode" declarations.
# Allow some "pure mode" declarations.
# When Cython is not installed, substitute a stand-in whose cast() simply
# returns its value, so the cython.cast(...) calls in this module still run
# as plain Python.
try:
  import cython
except ImportError:
  class FakeCython(object):
    @staticmethod
    def cast(type, value):
      return value
  # Bound through globals() instead of `cython = ...`; NOTE(review): presumably
  # to avoid rebinding the special `cython` name when compiled — confirm
  # before changing.
  globals()['cython'] = FakeCython()
+
+
# Template WindowedValue in the global window; ReadOperation and
# PGBKCVOperation derive outputs from it via with_value().
_globally_windowed_value = GlobalWindows.windowed_value(None)
# Exact type of the global window object, used by PGBKCVOperation.process for
# its single-global-window fast path.
_global_window_type = type(_globally_windowed_value.windows[0])
+
+
class ConsumerSet(Receiver):
  """Graph edge that fans one operation output out to its consumers.

  A ConsumerSet belongs to the producing Operation. Every element it receives
  is counted (via an OperationCounters instance) and then handed to the
  process() method of each consuming Operation, in registration order.
  """

  def __init__(
      self, counter_factory, step_name, output_index, consumers, coder):
    self.consumers = consumers
    self.opcounter = opcounters.OperationCounters(
        counter_factory, step_name, coder, output_index)
    # Remembered only so __repr__ can describe this edge.
    self.step_name, self.output_index, self.coder = (
        step_name, output_index, coder)

  def output(self, windowed_value):  # For old SDKs.
    self.receive(windowed_value)

  def receive(self, windowed_value):
    self.update_counters_start(windowed_value)
    for downstream in self.consumers:
      cython.cast(Operation, downstream).process(windowed_value)
    self.update_counters_finish()

  def update_counters_start(self, windowed_value):
    self.opcounter.update_from(windowed_value)

  def update_counters_finish(self):
    self.opcounter.update_collect()

  def __repr__(self):
    fields = (type(self).__name__, self.step_name, self.output_index,
              self.coder, len(self.consumers))
    return '%s[%s.out%s, coder=%s, len(consumers)=%s]' % fields
+
+
class Operation(object):
  """Executable (live) counterpart of an operation_specs.Worker* spec.

  An operation may have several outputs. Each output can feed any number of
  consumer operations, which are registered through add_receiver().
  """

  def __init__(self, operation_name, spec, counter_factory, state_sampler):
    """Creates a worker operation.

    Args:
      operation_name: System name the runner assigned to this operation. It is
        also the prefix of the sampler state names created here.
      spec: The operation_specs.Worker* instance this operation executes.
      counter_factory: CounterFactory used to create this operation's counters.
      state_sampler: StateSampler for the current operation.
    """
    self.operation_name = operation_name
    self.spec = spec
    self.counter_factory = counter_factory
    # Output index -> list of consumer operations (see add_receiver).
    self.consumers = collections.defaultdict(list)

    self.state_sampler = state_sampler
    make_state = self.state_sampler.scoped_state
    self.scoped_start_state = make_state(self.operation_name + '-start')
    self.scoped_process_state = make_state(self.operation_name + '-process')
    self.scoped_finish_state = make_state(self.operation_name + '-finish')
    # TODO(ccy): the '-abort' state can be added when the abort is supported in
    # Operations.

  def start(self):
    """Start operation."""
    root_logger = logging.getLogger()
    self.debug_logging_enabled = root_logger.isEnabledFor(logging.DEBUG)
    # Only WorkerSideInputSource (never a top-level operation) lacks
    # output_coders; everything else gets one ConsumerSet per output.
    output_coders = getattr(self.spec, 'output_coders', None)
    if output_coders:
      self.receivers = [
          ConsumerSet(self.counter_factory, self.step_name, index,
                      self.consumers[index], coder)
          for index, coder in enumerate(output_coders)]

  def finish(self):
    """Finish operation."""
    pass

  def process(self, o):
    """Process element in operation."""
    pass

  def output(self, windowed_value, output_index=0):
    """Sends windowed_value to the consumers of output number output_index."""
    receiver = self.receivers[output_index]
    cython.cast(Receiver, receiver).receive(windowed_value)

  def add_receiver(self, operation, output_index=0):
    """Adds a receiver operation for the specified output."""
    self.consumers[output_index].append(operation)

  def __str__(self):
    """Generates a compact, human-readable description of this operation.

    Pickled fields are omitted. The fields of the wrapped Worker* spec are
    shown inline, since Operation and operation_specs.Worker* map 1-1.

    Returns:
      Compact string representing this object.
    """
    return self.str_internal()

  def str_internal(self, is_recursive=False):
    """Helper for __str__ that can describe receivers recursively.

    Args:
      is_recursive: when True, the output is kept short (receivers and, if a
        step name is known, all spec fields are omitted).
    Returns:
      Compact string representing this object.
    """
    printable_name = self.__class__.__name__
    if hasattr(self, 'step_name'):
      printable_name += ' %s' % self.step_name
      if is_recursive:
        # A step name is enough detail for a nested description.
        return '<%s>' % printable_name

    printable_fields = (
        [] if self.spec is None
        else operation_specs.worker_printable_fields(self.spec))

    if not is_recursive:
      receivers = getattr(self, 'receivers', [])
      if receivers:
        printable_fields.append('receivers=[%s]' % ', '.join(
            str(receiver) for receiver in receivers))

    return '<%s %s>' % (printable_name, ', '.join(printable_fields))
+
+
class ReadOperation(Operation):
  """Reads an entire source bundle and emits its elements during start()."""

  def start(self):
    with self.scoped_start_state:
      super(ReadOperation, self).start()
      bundle = self.spec.source
      range_tracker = bundle.source.get_range_tracker(
          bundle.start_position, bundle.stop_position)
      for value in bundle.source.read(range_tracker):
        # Bare values are placed in the global window.
        windowed_value = (
            value if isinstance(value, WindowedValue)
            else _globally_windowed_value.with_value(value))
        self.output(windowed_value)
+
+
class InMemoryWriteOperation(Operation):
  """Appends every input element to the spec's in-memory output buffer."""

  def process(self, o):
    with self.scoped_process_state:
      if self.debug_logging_enabled:
        logging.debug('Processing [%s] in %s', o, self)
      # Keep the whole WindowedValue only when the spec asks for it.
      element = o if self.spec.write_windowed_values else o.value
      self.spec.output_buffer.append(element)
+
+
class _TaggedReceivers(dict):
  """Tag-to-receiver map; unknown tags resolve to a shared no-op receiver."""

  class NullReceiver(Receiver):
    """Receiver that discards everything it is given."""

    def receive(self, element):
      pass

    # For old SDKs.
    def output(self, element):
      pass

  def __missing__(self, unused_key):
    # Create the no-op receiver lazily and reuse it for every missing tag.
    null_receiver = getattr(self, '_null_receiver', None)
    if not null_receiver:
      null_receiver = self._null_receiver = _TaggedReceivers.NullReceiver()
    return null_receiver
+
+
class DoOperation(Operation):
  """A Do operation that will execute a custom DoFn for each input element."""

  def _read_side_inputs(self, tags_and_types):
    """Generator reading side inputs in the order prescribed by tags_and_types.

    Args:
      tags_and_types: Iterable of (side_tag, view_class, view_options)
        triples, one per side input. side_tag selects the entries of
        self.spec.side_inputs to read. view_options is either a dict or, for
        pre BEAM-733 SDKs, a tuple that is converted to a dict here.

    Yields:
      One apache_sideinputs.SideInputMap per triple, in the same order, built
      from view_class, the (normalized) view_options and an EmulatedIterable
      over all sources carrying that tag.

    Raises:
      NotImplementedError: if a matching side input is not a
        operation_specs.WorkerSideInputSource.
    """
    # We will read the side inputs in the order prescribed by the
    # tags_and_types argument because this is exactly the order needed to
    # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
    # getting the side inputs.
    #
    # Note that for each tag there could be several read operations in the
    # specification. This can happen for instance if the source has been
    # sharded into several files.
    for side_tag, view_class, view_options in tags_and_types:
      sources = []
      # Using the side_tag in the lambda below will trigger a pylint warning.
      # However in this case it is fine because the lambda is used right away
      # while the variable has the value assigned by the current iteration of
      # the for loop.
      # pylint: disable=cell-var-from-loop
      for si in itertools.ifilter(
          lambda o: o.tag == side_tag, self.spec.side_inputs):
        if not isinstance(si, operation_specs.WorkerSideInputSource):
          raise NotImplementedError('Unknown side input type: %r' % si)
        sources.append(si.source)
      iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)

      # Backwards compatibility for pre BEAM-733 SDKs.
      # Legacy singleton views encode (has_default, default); every other
      # legacy view gets empty options.
      if isinstance(view_options, tuple):
        if view_class == pvalue.SingletonPCollectionView:
          has_default, default = view_options
          view_options = {'default': default} if has_default else {}
        else:
          view_options = {}

      yield apache_sideinputs.SideInputMap(
          view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))

  def start(self):
    """Builds the DoFnRunner for this step and starts it.

    Raises:
      ValueError: if an output tag is neither PropertyNames.OUT nor prefixed
        with PropertyNames.OUT + '_'.
    """
    with self.scoped_start_state:
      # Must run first: it creates self.receivers, used below.
      super(DoOperation, self).start()

      # See fn_data in dataflow_runner.py
      fn, args, kwargs, tags_and_types, window_fn = (
          pickler.loads(self.spec.serialized_fn))

      state = common.DoFnState(self.counter_factory)
      state.step_name = self.step_name

      # TODO(silviuc): What is the proper label here? PCollection being
      # processed?
      context = common.DoFnContext('label', state=state)
      # Tag to output index map used to dispatch the side output values emitted
      # by the DoFn function to the appropriate receivers. The main output is
      # tagged with None and is associated with its corresponding index.
      tagged_receivers = _TaggedReceivers()

      output_tag_prefix = PropertyNames.OUT + '_'
      for index, tag in enumerate(self.spec.output_tags):
        if tag == PropertyNames.OUT:
          original_tag = None
        elif tag.startswith(output_tag_prefix):
          original_tag = tag[len(output_tag_prefix):]
        else:
          raise ValueError('Unexpected output name for operation: %s' % tag)
        tagged_receivers[original_tag] = self.receivers[index]

      # Side inputs are passed as a generator (see _read_side_inputs).
      self.dofn_runner = common.DoFnRunner(
          fn, args, kwargs, self._read_side_inputs(tags_and_types),
          window_fn, context, tagged_receivers,
          logger, self.step_name,
          scoped_metrics_container=self.scoped_metrics_container)
      # Use the runner directly when it already is a Receiver; otherwise wrap
      # it so process() can call receive() uniformly.
      self.dofn_receiver = (self.dofn_runner
                            if isinstance(self.dofn_runner, Receiver)
                            else DoFnRunnerReceiver(self.dofn_runner))

      self.dofn_runner.start()

  def finish(self):
    with self.scoped_finish_state:
      self.dofn_runner.finish()

  def process(self, o):
    with self.scoped_process_state:
      self.dofn_receiver.receive(o)
+
+
class DoFnRunnerReceiver(Receiver):
  """Receiver adapter for a DoFnRunner that is not itself a Receiver."""

  def __init__(self, dofn_runner):
    self.dofn_runner = dofn_runner

  def receive(self, windowed_value):
    # Delegate each received element to the wrapped runner's process().
    runner = self.dofn_runner
    runner.process(windowed_value)
+
+
class CombineOperation(Operation):
  """Applies one phase of a CombineFn to each (key, values) input element."""

  def __init__(self, operation_name, spec, counter_factory, state_sampler):
    super(CombineOperation, self).__init__(
        operation_name, spec, counter_factory, state_sampler)
    # Only (fn, args, kwargs) are needed: combiners take no deferred side
    # inputs, so the fourth pickled item is ignored, which keeps the extra
    # args/kwargs handling simpler than for ParDo's DoFns.
    unpickled = pickler.loads(self.spec.serialized_fn)
    fn, args, kwargs = unpickled[:3]
    self.phased_combine_fn = PhasedCombineFnExecutor(
        self.spec.phase, fn, args, kwargs)

  def finish(self):
    logging.debug('Finishing %s', self)

  def process(self, o):
    if self.debug_logging_enabled:
      logging.debug('Processing [%s] in %s', o, self)
    key, values = o.value
    with self.scoped_metrics_container:
      combined = self.phased_combine_fn.apply(values)
      self.output(o.with_value((key, combined)))
+
+
def create_pgbk_op(step_name, spec, counter_factory, state_sampler):
  """Returns the partial-GBK operation matching spec.

  A truthy spec.combine_fn selects the combining PGBKCVOperation; otherwise
  the plain grouping PGBKOperation is used.
  """
  op_class = PGBKCVOperation if spec.combine_fn else PGBKOperation
  return op_class(step_name, spec, counter_factory, state_sampler)
+
+
class PGBKOperation(Operation):
  """Partial group-by-key operation.

  This takes (windowed) input (key, value) tuples and outputs
  (key, [value]) tuples, performing a best effort group-by-key for
  values in this bundle, memory permitting.
  """

  def __init__(self, operation_name, spec, counter_factory, state_sampler):
    super(PGBKOperation, self).__init__(
        operation_name, spec, counter_factory, state_sampler)
    assert not self.spec.combine_fn
    # (key, windows) -> list of buffered WindowedValues for that group.
    self.table = collections.defaultdict(list)
    # Total number of buffered values across all groups in self.table.
    self.size = 0
    # TODO(robertwb) Make this configurable.
    self.max_size = 10 * 1000

  def process(self, o):
    # TODO(robertwb): Structural (hashable) values.
    key = o.value[0], tuple(o.windows)
    self.table[key].append(o)
    self.size += 1
    if self.size > self.max_size:
      # Drain down to 90% of capacity so we do not flush on every element.
      self.flush(9 * self.max_size // 10)

  def finish(self):
    self.flush(0)

  def flush(self, target):
    """Emits buffered groups until at most `target` values remain buffered.

    Each emitted group becomes a single (key, [values]) WindowedValue that
    uses the timestamp of the group's first buffered element.

    Args:
      target: number of buffered values allowed to remain after flushing;
        0 empties the table.
    """
    # Iterate over a snapshot so entries can be deleted while emitting.
    for kw, vs in list(self.table.items()):
      if self.size <= target:
        break
      del self.table[kw]
      # Keep the running count in sync with the table; previously it was never
      # decremented, so every process() call after the first overflow
      # re-flushed the whole table.
      self.size -= len(vs)
      key, windows = kw
      output_value = [v.value[1] for v in vs]
      windowed_value = WindowedValue(
          (key, output_value),
          vs[0].timestamp, windows)
      self.output(windowed_value)
+
+
class PGBKCVOperation(Operation):
  """Partial group-by-key that eagerly combines values per (windows, key).

  One accumulator is kept per (windows, key) pair in self.table. When a new
  key would exceed max_keys, about 10% of the entries are emitted and evicted.
  All remaining entries are emitted at finish(). Emitted elements are
  (key, accumulator) pairs.
  """

  def __init__(self, operation_name, spec, counter_factory, state_sampler):
    super(PGBKCVOperation, self).__init__(
        operation_name, spec, counter_factory, state_sampler)
    # Combiners do not accept deferred side-inputs (the ignored fourth
    # argument) and therefore the code to handle the extra args/kwargs is
    # simpler than for the DoFn's of ParDo.
    fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
    self.combine_fn = curry_combine_fn(fn, args, kwargs)
    if (getattr(fn.add_input, 'im_func', None)
        is core.CombineFn.add_input.im_func):
      # Old versions of the SDK have CombineFns that don't implement add_input.
      self.combine_fn_add_input = (
          lambda a, e: self.combine_fn.add_inputs(a, [e]))
    else:
      self.combine_fn_add_input = self.combine_fn.add_input
    # Optimization for the (known tiny accumulator, often wide keyspace)
    # combine functions.
    # TODO(b/36567833): Bound by in-memory size rather than key count.
    self.max_keys = (
        1000 * 1000 if
        isinstance(fn, (combiners.CountCombineFn, combiners.MeanCombineFn)) or
        # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized
        # combiners to the short list above.
        (isinstance(fn, core.CallableWrapperCombineFn) and
         fn._fn in (min, max, sum)) else 100 * 1000)  # pylint: disable=protected-access
    self.key_count = 0
    # (windows-or-0, key) -> one-element list holding the accumulator.
    self.table = {}

  def process(self, wkv):
    key, value = wkv.value
    # pylint: disable=unidiomatic-typecheck
    # Optimization for the global window case.
    # The int 0 stands in for "the single global window"; output_key checks it.
    if len(wkv.windows) == 1 and type(wkv.windows[0]) is _global_window_type:
      wkey = 0, key
    else:
      wkey = tuple(wkv.windows), key
    entry = self.table.get(wkey, None)
    if entry is None:
      if self.key_count >= self.max_keys:
        # Evict until roughly 90% of max capacity remains.
        target = self.key_count * 9 // 10
        old_wkeys = []
        # TODO(robertwb): Use an LRU cache?
        for old_wkey, old_wvalue in self.table.iteritems():
          old_wkeys.append(old_wkey)  # Can't mutate while iterating.
          self.output_key(old_wkey, old_wvalue[0])
          self.key_count -= 1
          if self.key_count <= target:
            break
        for old_wkey in reversed(old_wkeys):
          del self.table[old_wkey]
      self.key_count += 1
      # We save the accumulator as a one element list so we can efficiently
      # mutate when new values are added without searching the cache again.
      entry = self.table[wkey] = [self.combine_fn.create_accumulator()]
    entry[0] = self.combine_fn_add_input(entry[0], value)

  def finish(self):
    for wkey, value in self.table.iteritems():
      self.output_key(wkey, value[0])
    self.table = {}
    self.key_count = 0

  def output_key(self, wkey, value):
    """Emits (key, value) in the windows encoded by wkey."""
    windows, key = wkey
    # Compare by value: `is 0` relied on small-int caching and is a
    # SyntaxWarning on Python 3.8+. A windows tuple never equals 0.
    if windows == 0:
      self.output(_globally_windowed_value.with_value((key, value)))
    else:
      self.output(WindowedValue((key, value), windows[0].end, windows))
+
+
class FlattenOperation(Operation):
  """Flatten operation.

  Merges one or more producer operations into a single output stream by
  forwarding every element it receives, unchanged, to output 0.
  """

  def process(self, o):
    if self.debug_logging_enabled:
      logging.debug('Processing [%s] in %s', o, self)
    self.output(o, 0)
+
+
def create_operation(operation_name, spec, counter_factory, step_name,
                     state_sampler, test_shuffle_source=None,
                     test_shuffle_sink=None, is_streaming=False):
  """Create Operation object for given operation specification.

  Dispatches on the type of spec. Native, shuffle and GroupAlsoByWindows
  operations are imported lazily from the dataflow_worker package, so that
  package is only needed when such a spec is encountered.

  Args:
    operation_name: System name of the operation (used for state names).
    spec: An operation_specs.Worker* instance.
    counter_factory: CounterFactory passed to the created operation.
    step_name: Step name assigned to the operation and its metrics container.
    state_sampler: StateSampler passed to the created operation.
    test_shuffle_source: Injected shuffle source for shuffle-read operations
      (tests only).
    test_shuffle_sink: Injected shuffle sink for shuffle-write operations
      (tests only).
    is_streaming: Selects the streaming GroupAlsoByWindows implementation for
      WorkerMergeWindows specs.

  Returns:
    The new Operation with step_name, metrics_container and
    scoped_metrics_container set.

  Raises:
    TypeError: if spec is not one of the recognized Worker* types.
  """
  if isinstance(spec, operation_specs.WorkerRead):
    # SDK-defined sources run in-process; anything else is a native read.
    if isinstance(spec.source, iobase.SourceBundle):
      op = ReadOperation(
          operation_name, spec, counter_factory, state_sampler)
    else:
      from dataflow_worker.native_operations import NativeReadOperation
      op = NativeReadOperation(
          operation_name, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerWrite):
    from dataflow_worker.native_operations import NativeWriteOperation
    op = NativeWriteOperation(
        operation_name, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerCombineFn):
    op = CombineOperation(
        operation_name, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerPartialGroupByKey):
    op = create_pgbk_op(operation_name, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerDoFn):
    op = DoOperation(operation_name, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerGroupingShuffleRead):
    from dataflow_worker.shuffle_operations import GroupedShuffleReadOperation
    op = GroupedShuffleReadOperation(
        operation_name, spec, counter_factory, state_sampler,
        shuffle_source=test_shuffle_source)
  elif isinstance(spec, operation_specs.WorkerUngroupedShuffleRead):
    from dataflow_worker.shuffle_operations import UngroupedShuffleReadOperation
    op = UngroupedShuffleReadOperation(
        operation_name, spec, counter_factory, state_sampler,
        shuffle_source=test_shuffle_source)
  elif isinstance(spec, operation_specs.WorkerInMemoryWrite):
    op = InMemoryWriteOperation(
        operation_name, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerShuffleWrite):
    from dataflow_worker.shuffle_operations import ShuffleWriteOperation
    op = ShuffleWriteOperation(
        operation_name, spec, counter_factory, state_sampler,
        shuffle_sink=test_shuffle_sink)
  elif isinstance(spec, operation_specs.WorkerFlatten):
    op = FlattenOperation(
        operation_name, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerMergeWindows):
    from dataflow_worker.shuffle_operations import BatchGroupAlsoByWindowsOperation
    from dataflow_worker.shuffle_operations import StreamingGroupAlsoByWindowsOperation
    if is_streaming:
      op = StreamingGroupAlsoByWindowsOperation(
          operation_name, spec, counter_factory, state_sampler)
    else:
      op = BatchGroupAlsoByWindowsOperation(
          operation_name, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerReifyTimestampAndWindows):
    from dataflow_worker.shuffle_operations import ReifyTimestampAndWindowsOperation
    op = ReifyTimestampAndWindowsOperation(
        operation_name, spec, counter_factory, state_sampler)
  else:
    raise TypeError('Expected an instance of operation_specs.Worker* class '
                    'instead of %s' % (spec,))
  # Attributes shared by every operation type, set after construction.
  op.step_name = step_name
  op.metrics_container = MetricsContainer(step_name)
  op.scoped_metrics_container = ScopedMetricsContainer(op.metrics_container)
  return op
+
+
class SimpleMapTaskExecutor(object):
  """An executor for map tasks.

  Builds one Operation per spec of a MapTask, wires the operations together
  and runs them.
  """

  def __init__(
      self, map_task, counter_factory, state_sampler,
      test_shuffle_source=None, test_shuffle_sink=None):
    """Initializes SimpleMapTaskExecutor.

    Args:
      map_task: The map task we are to run.
      counter_factory: The CounterFactory instance for the work item.
      state_sampler: The StateSampler tracking the execution step.
      test_shuffle_source: Used during tests for dependency injection into
        shuffle read operation objects.
      test_shuffle_sink: Used during tests for dependency injection into
        shuffle write operation objects.
    """

    self._map_task = map_task
    self._counter_factory = counter_factory
    self._ops = []
    self._state_sampler = state_sampler
    self._test_shuffle_source = test_shuffle_source
    self._test_shuffle_sink = test_shuffle_sink

  def operations(self):
    """Returns a shallow copy of the operations created so far."""
    return self._ops[:]

  def execute(self):
    """Executes all the operation_specs.Worker* instructions in a map task.

    Operations are created in map-task order and attached as receivers of
    their producers. Specs refer to producers by list index, so a producer
    always precedes its consumers. Operations are then started in reverse
    order, so every consumer is started before any producer emits from its
    own start(). They are finished in forward order, so a producer finishes
    (and may emit buffered output) before its consumers finish. Counters are
    reported through the counter factory given to each operation.

    Raises:
      TypeError: if a spec is not an instance of a recognized
        operation_specs.Worker* class (raised by create_operation).
    """

    # operations is a list of operation_specs.Worker* instances.
    # The order of the elements is important because the inputs use
    # list indexes as references.

    step_names = (
        self._map_task.step_names or [None] * len(self._map_task.operations))
    for ix, spec in enumerate(self._map_task.operations):
      # This is used for logging and assigning names to counters.
      operation_name = self._map_task.system_names[ix]
      step_name = step_names[ix]
      op = create_operation(
          operation_name, spec, self._counter_factory, step_name,
          self._state_sampler,
          test_shuffle_source=self._test_shuffle_source,
          test_shuffle_sink=self._test_shuffle_sink)
      self._ops.append(op)

      # Add receiver operations to the appropriate producers.
      if hasattr(op.spec, 'input'):
        producer, output_index = op.spec.input
        self._ops[producer].add_receiver(op, output_index)
      # Flatten has 'inputs', not 'input'
      if hasattr(op.spec, 'inputs'):
        for producer, output_index in op.spec.inputs:
          self._ops[producer].add_receiver(op, output_index)

    for ix, op in reversed(list(enumerate(self._ops))):
      logging.debug('Starting op %d %s', ix, op)
      with op.scoped_metrics_container:
        op.start()
    for op in self._ops:
      with op.scoped_metrics_container:
        op.finish()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
new file mode 100644
index 0000000..6907f6e
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -0,0 +1,451 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""SDK harness for executing Python Fns via the Fn API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import logging
+import Queue as queue
+import threading
+import traceback
+import zlib
+
+import dill
+from google.protobuf import wrappers_pb2
+
+from apache_beam.coders import coder_impl
+from apache_beam.coders import WindowedValueCoder
+from apache_beam.internal import pickler
+from apache_beam.runners.dataflow.native_io import iobase
+from apache_beam.utils import counters
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
+try:
+ from apache_beam.runners.worker import statesampler
+except ImportError:
+ from apache_beam.runners.worker import statesampler_fake as statesampler
+from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory
+
+
# URNs identifying the kinds of FunctionSpec this harness recognizes.
# DATA_INPUT_URN, DATA_OUTPUT_URN, IDENTITY_DOFN_URN, PYTHON_DOFN_URN and
# PYTHON_SOURCE_URN are dispatched on in SdkWorker.create_execution_tree;
# PYTHON_ITERABLE_VIEWFN_URN and PYTHON_CODER_URN are not referenced in this
# file.
DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1'
DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1'
IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1'
PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1'
PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1'
# NOTE: the two URNs below are used for *Python* DoFns and sources even though
# their strings say "java"; they match what the runner currently sends.
# TODO(vikasrk): Fix this once runner sends appropriate python urns.
PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1'
PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1'
+
+
class RunnerIOOperation(operations.Operation):
  """Shared base for operations that move bytes over the Fn API data plane.

  Holds the windowed coder, the data-plane target and the data channel, and
  wires every downstream consumer to output index 0.
  """

  def __init__(self, operation_name, step_name, consumers, counter_factory,
               state_sampler, windowed_coder, target, data_channel):
    super(RunnerIOOperation, self).__init__(
        operation_name, None, counter_factory, state_sampler)
    self.windowed_coder = windowed_coder
    self.step_name = step_name
    # For a DataInputOperation, `target` names the consumer of the incoming
    # bytes; for a DataOutputOperation, it names the producer of the bytes.
    self.target = target
    self.data_channel = data_channel
    # Output tags are ignored: every consumer receives output index 0.
    for downstream_ops in consumers.values():
      for downstream in downstream_ops:
        self.add_receiver(downstream, 0)
+
+
class DataOutputOperation(RunnerIOOperation):
  """Sink-like operation that encodes outputs to be sent back to the runner.

  `set_output_stream` must be called before `process` or `finish`, since both
  use the stream it stores.
  """

  def set_output_stream(self, output_stream):
    """Attaches the stream that encoded elements are written to."""
    self.output_stream = output_stream

  def process(self, windowed_value):
    """Encodes `windowed_value` onto the attached output stream."""
    encoder = self.windowed_coder.get_impl()
    encoder.encode_to_stream(windowed_value, self.output_stream, True)

  def finish(self):
    """Closes the output stream, then performs the base-class finish."""
    self.output_stream.close()
    super(DataOutputOperation, self).finish()
+
+
class DataInputOperation(RunnerIOOperation):
  """Source-like operation that decodes input received from the runner."""

  def __init__(self, operation_name, step_name, consumers, counter_factory,
               state_sampler, windowed_coder, input_target, data_channel):
    super(DataInputOperation, self).__init__(
        operation_name, step_name, consumers, counter_factory, state_sampler,
        windowed_coder, target=input_target, data_channel=data_channel)
    # We must do this manually as we don't have a spec or spec.output_coders.
    # Only the consumers of the first tag in `consumers` (in iteration order)
    # are used. NOTE(review): presumably an input transform never has more
    # than one output tag; confirm.
    # next(iter(...)) works on both Python 2 and 3, unlike
    # dict.itervalues().next(), which is Python 2 only.
    self.receivers = [
        operations.ConsumerSet(self.counter_factory, self.step_name, 0,
                               next(iter(consumers.values())),
                               self.windowed_coder)]

  def process(self, windowed_value):
    """Forwards an already-decoded `windowed_value` downstream."""
    self.output(windowed_value)

  def process_encoded(self, encoded_windowed_values):
    """Decodes every element in a data-plane payload and forwards it.

    Args:
      encoded_windowed_values: concatenated encodings of windowed values, in
        the format decoded by `self.windowed_coder`.
    """
    # The coder implementation does not change between elements; look it up
    # once instead of once per element.
    decoder = self.windowed_coder.get_impl()
    input_stream = coder_impl.create_InputStream(encoded_windowed_values)
    while input_stream.size() > 0:
      self.output(decoder.decode_from_stream(input_stream, True))
+
+
# TODO(robertwb): Revise side input API to not be in terms of native sources.
# This will enable lookups, but there's an open question as to how to handle
# custom sources without forcing intermediate materialization. This seems very
# related to the desire to inject key and window preserving [Splittable]DoFns
# into the view computation.
class SideInputSource(iobase.NativeSource, iobase.NativeSourceReader):
  """Reads a side input through state API calls, posing as a native source.

  An instance is its own reader and a no-op context manager around
  iteration.
  """

  def __init__(self, state_handler, state_key, coder):
    # `state_handler.Get(state_key)` supplies the encoded side input bytes.
    self._state_handler = state_handler
    self._state_key = state_key
    self._coder = coder

  def reader(self):
    # The source doubles as its own reader.
    return self

  @property
  def returns_windowed_values(self):
    # Elements yielded by __iter__ are reported as windowed values.
    return True

  def __enter__(self):
    return self

  def __exit__(self, *exn_info):
    # Nothing to release; returning None lets any exception propagate.
    return None

  def __iter__(self):
    # TODO(robertwb): Support pagination.
    # The whole side input is currently fetched with a single Get call.
    encoded = self._state_handler.Get(self._state_key).data
    stream = coder_impl.create_InputStream(encoded)
    while stream.size() > 0:
      element = self._coder.get_impl().decode_from_stream(stream, True)
      yield element
+
+
def unpack_and_deserialize_py_fn(function_spec):
  """Unpickles the Python object stored in a FunctionSpec proto.

  NOTE(review): this unpickles bytes received over the Fn API; it is only
  safe on payloads from a trusted runner.
  """
  pickled = unpack_function_spec_data(function_spec)
  return pickler.loads(pickled)
+
+
def unpack_function_spec_data(function_spec):
  """Returns the bytes packed into `function_spec.data`.

  `function_spec.data` is an `Any` expected to hold a `BytesValue`, which is
  what `pack_function_spec_data` produces.

  Raises:
    ValueError: if `function_spec.data` does not hold a `BytesValue`.
  """
  data = wrappers_pb2.BytesValue()
  # Any.Unpack returns False (and leaves `data` empty) on a type mismatch.
  # Surface that instead of silently returning empty bytes, which would only
  # fail later with a confusing unpickling/JSON error.
  if not function_spec.data.Unpack(data):
    raise ValueError(
        'FunctionSpec %r does not contain a BytesValue payload.'
        % function_spec.id)
  return data.value
+
+
# pylint: disable=redefined-builtin
def serialize_and_pack_py_fn(fn, urn, id=None):
  """Pickles `fn` and wraps it in a FunctionSpec with the given urn and id."""
  serialized = pickler.dumps(fn)
  return pack_function_spec_data(serialized, urn, id)
# pylint: enable=redefined-builtin
+
+
# pylint: disable=redefined-builtin
def pack_function_spec_data(value, urn, id=None):
  """Builds a FunctionSpec whose `data` holds `value` as a BytesValue.

  Args:
    value: the bytes payload to pack.
    urn: the urn to set on the spec.
    id: optional spec id; left unset when falsy.

  Returns:
    The populated `beam_fn_api_pb2.FunctionSpec`.
  """
  payload = wrappers_pb2.BytesValue(value=value)
  spec = beam_fn_api_pb2.FunctionSpec(urn=urn)
  spec.data.Pack(payload)
  if id:
    spec.id = id
  return spec
# pylint: enable=redefined-builtin
+
+
# TODO(vikasrk): move this method to ``coders.py`` in the SDK.
def load_compressed(compressed_data):
  """Returns a decompressed and deserialized python object.

  Args:
    compressed_data: zlib-compressed, dill-pickled bytes. Any base64 layer has
      already been removed by the runner (see note below).

  Returns:
    The object produced by `dill.loads`.

  NOTE(review): dill.loads can execute arbitrary code; only call this on
  payloads from a trusted runner.
  """
  # Note: SDK uses ``pickler.dumps`` to serialize certain python objects
  # (like sources), which involves serialization, compression and base64
  # encoding. We cannot directly use ``pickler.loads`` for
  # deserialization, as the runner would have already base64 decoded the
  # data. So we only need to decompress and deserialize.

  data = zlib.decompress(compressed_data)
  try:
    return dill.loads(data)
  except Exception:  # pylint: disable=broad-except
    # Retry once with dill's private trace flag enabled, presumably to get
    # diagnostic output; if the retry also fails, its exception propagates.
    dill.dill._trace(True)  # pylint: disable=protected-access
    return dill.loads(data)
  finally:
    # Reset the trace flag on every path (success, retry success, failure).
    dill.dill._trace(False)  # pylint: disable=protected-access
+
+
class SdkHarness(object):
  """Connects to the runner's control service and executes its instructions.

  Requests arriving on the Control stream are handed to an SdkWorker, and
  each InstructionResponse is sent back on the request side of that stream.
  """

  def __init__(self, control_channel):
    self._control_channel = control_channel
    self._data_channel_factory = GrpcClientDataChannelFactory()

  def run(self):
    """Processes control requests until the Control stream ends.

    Failures of individual instructions are reported to the runner as error
    responses. A failure of the stream itself is logged, cleanup still runs,
    and the exception is then re-raised here so the caller sees it.
    """
    control_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel)
    # TODO(robertwb): Wire up to new state api.
    state_stub = None
    self.worker = SdkWorker(state_stub, self._data_channel_factory)

    responses = queue.Queue()
    no_more_work = object()
    # Exceptions that terminated the request loop in the worker thread.
    failures = []

    def get_responses():
      # Request-side iterator of the Control stream; ends once the
      # no_more_work sentinel is dequeued.
      while True:
        response = responses.get()
        if response is no_more_work:
          return
        yield response

    def process_requests():
      try:
        for work_request in control_stub.Control(get_responses()):
          logging.info('Got work %s', work_request.instruction_id)
          try:
            response = self.worker.do_instruction(work_request)
          except Exception:  # pylint: disable=broad-except
            # Report the instruction failure to the runner instead of
            # tearing down the harness.
            response = beam_fn_api_pb2.InstructionResponse(
                instruction_id=work_request.instruction_id,
                error=traceback.format_exc())
          responses.put(response)
      except Exception as exn:  # pylint: disable=broad-except
        # An exception escaping a Thread target is only printed to stderr and
        # then lost; log it and hand it back to run() instead.
        logging.exception('Error while processing control requests.')
        failures.append(exn)

    request_thread = threading.Thread(target=process_requests)
    request_thread.start()
    try:
      request_thread.join()
    finally:
      # get_responses may be blocked on responses.get(), but we need to return
      # control to its caller.
      responses.put(no_more_work)
      self._data_channel_factory.close()
    if failures:
      raise failures[0]
    logging.info('Done consuming work.')
+
+
class SdkWorker(object):
  """Executes Fn API instructions (register, initial split, process bundle).

  Registered descriptors and function specs are cached in `self.fns`, keyed
  by id, so that later instructions can refer to them by reference.
  """

  def __init__(self, state_handler, data_channel_factory):
    # Maps descriptor / function-spec ids to the registered protos.
    self.fns = {}
    self.state_handler = state_handler
    self.data_channel_factory = data_channel_factory

  @staticmethod
  def _only_element(iterable):
    """Returns the single element of `iterable`.

    Raises:
      ValueError: if `iterable` does not have exactly one element.
    """
    element, = iterable
    return element

  def do_instruction(self, request):
    """Dispatches `request` to the method named by its populated oneof field.

    E.g. if register is set, this will construct
    InstructionResponse(register=self.register(request.register, id)).

    Raises:
      NotImplementedError: if no request field is set.
    """
    request_type = request.WhichOneof('request')
    if not request_type:
      raise NotImplementedError
    handler = getattr(self, request_type)
    return beam_fn_api_pb2.InstructionResponse(**{
        'instruction_id': request.instruction_id,
        request_type: handler(getattr(request, request_type),
                              request.instruction_id),
    })

  def register(self, request, unused_instruction_id=None):
    """Caches each descriptor and its transforms' function specs by id."""
    for process_bundle_descriptor in request.process_bundle_descriptor:
      self.fns[process_bundle_descriptor.id] = process_bundle_descriptor
      for p_transform in process_bundle_descriptor.primitive_transform:
        self.fns[p_transform.function_spec.id] = p_transform.function_spec
    return beam_fn_api_pb2.RegisterResponse()

  def initial_source_split(self, request, unused_instruction_id=None):
    """Splits a registered Python source into bundles of the desired size.

    Raises:
      ValueError: if the referenced spec is not a Python source.
    """
    source_spec = self.fns[request.source_reference]
    # An explicit check rather than `assert`, which `python -O` strips.
    if source_spec.urn != PYTHON_SOURCE_URN:
      raise ValueError(
          'Unexpected source urn %r; expected %r.'
          % (source_spec.urn, PYTHON_SOURCE_URN))
    source_bundle = unpack_and_deserialize_py_fn(source_spec)
    splits = source_bundle.source.split(request.desired_bundle_size_bytes,
                                        source_bundle.start_position,
                                        source_bundle.stop_position)
    response = beam_fn_api_pb2.InitialSourceSplitResponse()
    response.splits.extend([
        beam_fn_api_pb2.SourceSplit(
            source=serialize_and_pack_py_fn(split, PYTHON_SOURCE_URN),
            relative_size=split.weight,
        )
        for split in splits
    ])
    return response

  def create_execution_tree(self, descriptor):
    """Builds the operations for a ProcessBundleDescriptor.

    Returns:
      A pair `(ops, ops_by_id)`: the operations in the descriptor's
      `primitive_transform` order, and a dict from transform id to operation.

    Raises:
      NotImplementedError: for a transform whose function_spec urn is not
        supported.
    """
    # TODO(vikasrk): Add an id field to Coder proto and use that instead.
    coders = {
        coder.function_spec.id: operation_specs.get_coder_from_spec(
            json.loads(unpack_function_spec_data(coder.function_spec)))
        for coder in descriptor.coders}

    counter_factory = counters.CounterFactory()
    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    state_sampler = statesampler.StateSampler(
        'fnapi-step%s-' % descriptor.id, counter_factory)
    # consumers[producer_id][output_name] -> ops consuming that output.
    consumers = collections.defaultdict(lambda: collections.defaultdict(list))
    ops_by_id = {}
    reversed_ops = []

    def create_side_input(tag, si):
      # TODO(robertwb): Extract windows (and keys) out of element data.
      return operation_specs.WorkerSideInputSource(
          tag=tag,
          source=SideInputSource(
              self.state_handler,
              beam_fn_api_pb2.StateKey(
                  function_spec_reference=si.view_fn.id),
              coder=unpack_and_deserialize_py_fn(si.view_fn)))

    # Visiting in reverse means a transform's consumers have already been
    # recorded in `consumers` when it is built (assuming the descriptor lists
    # producers before their consumers).
    for transform in reversed(descriptor.primitive_transform):
      # TODO(robertwb): Figure out how to plumb through the operation name (e.g.
      # "s3") from the service through the FnAPI so that msec counters can be
      # reported and correctly plumbed through the service and the UI.
      operation_name = 'fnapis%s' % transform.id
      urn = transform.function_spec.urn

      if urn == DATA_OUTPUT_URN:
        target = beam_fn_api_pb2.Target(
            primitive_transform_reference=transform.id,
            name=self._only_element(transform.outputs.keys()))
        op = DataOutputOperation(
            operation_name,
            transform.step_name,
            consumers[transform.id],
            counter_factory,
            state_sampler,
            coders[self._only_element(
                transform.outputs.values()).coder_reference],
            target,
            self.data_channel_factory.create_data_channel(
                transform.function_spec))

      elif urn == DATA_INPUT_URN:
        target = beam_fn_api_pb2.Target(
            primitive_transform_reference=transform.id,
            name=self._only_element(transform.inputs.keys()))
        op = DataInputOperation(
            operation_name,
            transform.step_name,
            consumers[transform.id],
            counter_factory,
            state_sampler,
            coders[self._only_element(
                transform.outputs.values()).coder_reference],
            target,
            self.data_channel_factory.create_data_channel(
                transform.function_spec))

      elif urn == PYTHON_DOFN_URN:
        output_tags = list(transform.outputs.keys())
        spec = operation_specs.WorkerDoFn(
            serialized_fn=unpack_function_spec_data(transform.function_spec),
            output_tags=output_tags,
            input=None,
            side_inputs=[create_side_input(tag, si)
                         for tag, si in transform.side_inputs.items()],
            output_coders=[coders[transform.outputs[out].coder_reference]
                           for out in output_tags])

        op = operations.DoOperation(operation_name, spec, counter_factory,
                                    state_sampler)
        # TODO(robertwb): Move these to the constructor.
        op.step_name = transform.step_name
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(consumer, output_tags.index(tag))

      elif urn == IDENTITY_DOFN_URN:
        op = operations.FlattenOperation(operation_name, None, counter_factory,
                                         state_sampler)
        # TODO(robertwb): Move these to the constructor.
        op.step_name = transform.step_name
        for op_consumers in consumers[transform.id].values():
          for consumer in op_consumers:
            op.add_receiver(consumer, 0)

      elif urn == PYTHON_SOURCE_URN:
        source = load_compressed(unpack_function_spec_data(
            transform.function_spec))
        # TODO(vikasrk): Remove this once custom source is implemented with
        # splittable dofn via the data plane.
        spec = operation_specs.WorkerRead(
            iobase.SourceBundle(1.0, source, None, None),
            [WindowedValueCoder(source.default_output_coder())])
        op = operations.ReadOperation(operation_name, spec, counter_factory,
                                      state_sampler)
        op.step_name = transform.step_name
        output_tags = list(transform.outputs.keys())
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(consumer, output_tags.index(tag))

      else:
        raise NotImplementedError(urn)

      # Record this op as a consumer of each producer output it reads.
      for inputs in transform.inputs.values():
        for input_target in inputs.target:
          consumers[input_target.primitive_transform_reference][
              input_target.name].append(op)

      reversed_ops.append(op)
      ops_by_id[transform.id] = op

    return list(reversed(reversed_ops)), ops_by_id

  def process_bundle(self, request, instruction_id):
    """Runs one bundle of the referenced ProcessBundleDescriptor.

    Attaches data-plane output streams, starts the operations in reverse
    order, feeds each data input from the data plane, then finishes the
    operations in descriptor order.
    """
    ops, ops_by_id = self.create_execution_tree(
        self.fns[request.process_bundle_descriptor_reference])

    expected_inputs = []
    for op in ops_by_id.values():
      if isinstance(op, DataOutputOperation):
        # TODO(robertwb): Is there a better way to pass the instruction id to
        # the operation?
        op.set_output_stream(op.data_channel.output_stream(
            instruction_id, op.target))
      elif isinstance(op, DataInputOperation):
        # We must wait until we receive "end of stream" for each of these ops.
        expected_inputs.append(op)

    # Start all operations.
    for op in reversed(ops):
      logging.info('start %s', op)
      op.start()

    # Inject inputs from data plane.
    for input_op in expected_inputs:
      for data in input_op.data_channel.input_elements(
          instruction_id, [input_op.target]):
        # ignores input name
        target_op = ops_by_id[data.target.primitive_transform_reference]
        # lacks coder for non-input ops
        target_op.process_encoded(data.data)

    # Finish all operations.
    for op in ops:
      logging.info('finish %s', op)
      op.finish()

    return beam_fn_api_pb2.ProcessBundleResponse()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
new file mode 100644
index 0000000..28828c3
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""SDK Fn Harness entry point."""
+
+import logging
+import os
+import sys
+
+import grpc
+from google.protobuf import text_format
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler
+from apache_beam.runners.worker.sdk_worker import SdkHarness
+
+
def main(unused_argv):
  """Main entry point for SDK Fn Harness.

  Reads the logging and control service descriptors (text-format
  ApiServiceDescriptor protos) from the LOGGING_API_SERVICE_DESCRIPTOR and
  CONTROL_API_SERVICE_DESCRIPTOR environment variables, sends all log records
  to the runner, and runs an SdkHarness over an insecure gRPC channel to the
  control service.

  Raises:
    NotImplementedError: if the control service descriptor requests OAuth2
      client credentials, which are not supported yet.
  """
  logging_service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
  text_format.Merge(os.environ['LOGGING_API_SERVICE_DESCRIPTOR'],
                    logging_service_descriptor)

  # Send all logs to the runner.
  fn_log_handler = FnApiLogRecordHandler(logging_service_descriptor)
  root_logger = logging.getLogger()
  # TODO(vikasrk): This should be picked up from pipeline options.
  root_logger.setLevel(logging.INFO)
  root_logger.addHandler(fn_log_handler)

  try:
    logging.info('Python sdk harness started.')
    service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
    text_format.Merge(os.environ['CONTROL_API_SERVICE_DESCRIPTOR'],
                      service_descriptor)
    # TODO(robertwb): Support credentials.
    # An explicit check rather than `assert`, which `python -O` strips; that
    # would silently ignore the requested credentials.
    if service_descriptor.oauth2_client_credentials_grant.url:
      raise NotImplementedError(
          'OAuth2 client credentials are not supported by the Python SDK '
          'harness.')
    channel = grpc.insecure_channel(service_descriptor.url)
    SdkHarness(channel).run()
    logging.info('Python sdk harness exiting.')
  except Exception:  # pylint: disable=broad-except
    logging.exception('Python sdk harness failed: ')
    raise
  finally:
    # Detach before closing so that later log records are not routed to a
    # closed handler.
    root_logger.removeHandler(fn_log_handler)
    fn_log_handler.close()


if __name__ == '__main__':
  main(sys.argv)