Posted to commits@beam.apache.org by al...@apache.org on 2017/12/15 21:52:17 UTC
[beam] branch master updated: [BEAM-3189] Sdk worker multithreading (#4134)
This is an automated email from the ASF dual-hosted git repository.
altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 049070b [BEAM-3189] Sdk worker multithreading (#4134)
049070b is described below
commit 049070be4b4e6f1fc16092b7308ed4054e6ef76a
Author: Ankur <an...@users.noreply.github.com>
AuthorDate: Fri Dec 15 13:52:14 2017 -0800
[BEAM-3189] Sdk worker multithreading (#4134)
* Adding debug server to sdk worker to get threaddumps
* Adding multi threaded function registration test
* Wrapping SdkWorker to associate more state with it.
* Making multiple workers work in parallel
* Adding an experimental option for worker_threads (usage sketch below)
* Fixing experimental flag for worker_threads
* Fix Dataplane Concurrency bug
* Making the method prefix a constant.
* Simplifying code and selecting worker based on availability
* Fixing lint errors
* Moving worker count extraction to sdk_worker_main
* Adding test cases and comments
* Fixing lint error
* Making worker_count have no default.
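For illustration (not part of this commit): a minimal sketch of how the new worker_threads experiment can be supplied through pipeline options. The --experiments flag and the PipelineOptions import path are assumptions based on standard Beam usage and are not shown in this diff.

    from apache_beam.options.pipeline_options import PipelineOptions

    # Hypothetical usage: ask the SDK harness for 4 worker threads via the
    # experimental flag introduced by this commit.
    options = PipelineOptions(['--experiments', 'worker_threads=4'])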
---
.../runners/portability/fn_api_runner.py | 9 +-
.../runners/portability/fn_api_runner_test.py | 13 +-
.../apache_beam/runners/worker/data_plane.py | 23 +--
.../apache_beam/runners/worker/sdk_worker.py | 167 ++++++++++++++-------
.../apache_beam/runners/worker/sdk_worker_main.py | 38 ++++-
.../runners/worker/sdk_worker_main_test.py | 33 +++-
.../apache_beam/runners/worker/sdk_worker_test.py | 68 ++++++---
7 files changed, 258 insertions(+), 93 deletions(-)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 9b143c6..a7a73f4 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -855,7 +855,7 @@ class FnApiRunner(runner.PipelineRunner):
self.data_plane_handler = data_plane.InMemoryDataChannel()
self.worker = sdk_worker.SdkWorker(
self.state_handler, data_plane.InMemoryDataChannelFactory(
- self.data_plane_handler.inverse()))
+ self.data_plane_handler.inverse()), {})
def push(self, request):
logging.debug('CONTROL REQUEST %s', request)
@@ -907,8 +907,11 @@ class FnApiRunner(runner.PipelineRunner):
self.data_server.start()
self.control_server.start()
- self.worker = (self.sdk_harness_factory or sdk_worker.SdkHarness)(
- 'localhost:%s' % self.control_port)
+ self.worker = self.sdk_harness_factory(
+ 'localhost:%s' % self.control_port
+ ) if self.sdk_harness_factory else sdk_worker.SdkHarness(
+ 'localhost:%s' % self.control_port, worker_count=1)
+
self.worker_thread = threading.Thread(
name='run_worker', target=self.worker.run)
logging.info('starting worker')
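As the hunk above shows, sdk_harness_factory can be any callable that takes the control address; mirroring the test added later in this commit, a two-thread harness can be requested like this:

    import functools

    from apache_beam.runners.portability import fn_api_runner
    from apache_beam.runners.worker import sdk_worker

    # Bind worker_count ahead of time; the runner supplies the control address.
    runner = fn_api_runner.FnApiRunner(
        use_grpc=True,
        sdk_harness_factory=functools.partial(
            sdk_worker.SdkHarness, worker_count=2))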
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 249eece..eb297ab 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import functools
import logging
import time
import unittest
@@ -22,6 +22,7 @@ import unittest
import apache_beam as beam
from apache_beam.runners.portability import fn_api_runner
from apache_beam.runners.portability import maptask_executor_runner_test
+from apache_beam.runners.worker import sdk_worker
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import window
@@ -155,6 +156,16 @@ class FnApiRunnerTestWithGrpc(FnApiRunnerTest):
runner=fn_api_runner.FnApiRunner(use_grpc=True))
+class FnApiRunnerTestWithGrpcMultiThreaded(FnApiRunnerTest):
+
+ def create_pipeline(self):
+ return beam.Pipeline(
+ runner=fn_api_runner.FnApiRunner(
+ use_grpc=True,
+ sdk_harness_factory=functools.partial(
+ sdk_worker.SdkHarness, worker_count=2)))
+
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index f2a3751..b3d4854 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -277,20 +277,23 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
def __init__(self):
self._data_channel_cache = {}
+ self._lock = threading.Lock()
def create_data_channel(self, remote_grpc_port):
url = remote_grpc_port.api_service_descriptor.url
if url not in self._data_channel_cache:
- logging.info('Creating channel for %s', url)
- grpc_channel = grpc.insecure_channel(
- url,
- # Options to have no limits (-1) on the size of the messages
- # received or sent over the data plane. The actual buffer size is
- # controlled in a layer above.
- options=[("grpc.max_receive_message_length", -1),
- ("grpc.max_send_message_length", -1)])
- self._data_channel_cache[url] = GrpcClientDataChannel(
- beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel))
+ with self._lock:
+ if url not in self._data_channel_cache:
+ logging.info('Creating channel for %s', url)
+ grpc_channel = grpc.insecure_channel(
+ url,
+ # Options to have no limits (-1) on the size of the messages
+ # received or sent over the data plane. The actual buffer size is
+ # controlled in a layer above.
+ options=[("grpc.max_receive_message_length", -1),
+ ("grpc.max_send_message_length", -1)])
+ self._data_channel_cache[url] = GrpcClientDataChannel(
+ beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel))
return self._data_channel_cache[url]
def close(self):
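The hunk above is the double-checked locking pattern: the unsynchronized first check keeps the common cache-hit path lock-free, and the second check under the lock prevents two threads from racing to create the same channel. A self-contained sketch of the pattern (ChannelCache and the factory argument are illustrative names, not the commit's classes):

    import threading

    class ChannelCache(object):
        def __init__(self, factory):
            self._cache = {}
            self._lock = threading.Lock()
            self._factory = factory

        def get(self, key):
            # Fast path: no lock is taken when the entry already exists.
            if key not in self._cache:
                with self._lock:
                    # Re-check under the lock: another thread may have
                    # created the entry since the first check.
                    if key not in self._cache:
                        self._cache[key] = self._factory(key)
            return self._cache[key]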
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index fec844e..980357e 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -14,14 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
"""SDK harness for executing Python Fns via the Fn API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import logging
import Queue as queue
import sys
@@ -38,27 +36,54 @@ from apache_beam.runners.worker import data_plane
class SdkHarness(object):
+ REQUEST_METHOD_PREFIX = '_request_'
- def __init__(self, control_address):
+ def __init__(self, control_address, worker_count):
+ self._worker_count = worker_count
+ self._worker_index = 0
self._control_channel = grpc.insecure_channel(control_address)
self._data_channel_factory = data_plane.GrpcClientDataChannelFactory()
- # TODO: Ensure thread safety to run with more than 1 thread.
- self._default_work_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
+ self.workers = queue.Queue()
+ # One thread is enough for getting the progress report.
+ # Assumption:
+ # Progress report generation should not do IO or wait on other resources.
+ # Without waits, having multiple threads will not improve performance and
+ # will only add complexity.
self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
+ self._process_thread_pool = futures.ThreadPoolExecutor(
+ max_workers=self._worker_count)
+ self._instruction_id_vs_worker = {}
+ self._fns = {}
+ self._responses = queue.Queue()
+ self._process_bundle_queue = queue.Queue()
+ logging.info('Initializing SDKHarness with %s workers.', self._worker_count)
def run(self):
control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel)
- state_stub = beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel)
- state_handler = GrpcStateHandler(state_stub)
- state_handler.start()
- self.worker = SdkWorker(state_handler, self._data_channel_factory)
-
- responses = queue.Queue()
no_more_work = object()
+ # Create workers
+ for _ in range(self._worker_count):
+ state_handler = GrpcStateHandler(
+ beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel))
+ state_handler.start()
+ # SdkHarness manages function registration and shares self._fns with all
+ # the workers. This is needed because function registration (register)
+ # and execution (process_bundle) are sent over different requests, and we
+ # do not know which worker is going to process the bundle for a function
+ # until we get the process_bundle request. Moreover, the same function is
+ # reused by different process_bundle calls and may be executed by
+ # different workers. Hence we need a centralized function list shared
+ # among all the workers.
+ self.workers.put(
+ SdkWorker(
+ state_handler=state_handler,
+ data_channel_factory=self._data_channel_factory,
+ fns=self._fns))
+
def get_responses():
while True:
- response = responses.get()
+ response = self._responses.get()
if response is no_more_work:
return
yield response
@@ -66,54 +91,86 @@ class SdkHarness(object):
for work_request in control_stub.Control(get_responses()):
logging.info('Got work %s', work_request.instruction_id)
request_type = work_request.WhichOneof('request')
- # WhichOneOf returns the name of the set field as a single string
- if request_type in ['process_bundle_progress']:
- thread_pool = self._progress_thread_pool
- else:
- thread_pool = self._default_work_thread_pool
-
- # Need this wrapper to capture the original stack trace.
- def do_instruction(request):
- try:
- return self.worker.do_instruction(request)
- except Exception as e: # pylint: disable=broad-except
- traceback_str = traceback.format_exc(e)
- raise Exception("Error processing request. Original traceback "
- "is\n%s\n" % traceback_str)
-
- def handle_response(request, response_future):
- try:
- response = response_future.result()
- except Exception as e: # pylint: disable=broad-except
- logging.error(
- 'Error processing instruction %s',
- request.instruction_id,
- exc_info=True)
- response = beam_fn_api_pb2.InstructionResponse(
- instruction_id=request.instruction_id,
- error=str(e))
- responses.put(response)
-
- thread_pool.submit(do_instruction, work_request).add_done_callback(
- functools.partial(handle_response, work_request))
-
- logging.info("No more requests from control plane")
- logging.info("SDK Harness waiting for in-flight requests to complete")
+ # Namespace the request method with '_request_'. The called method
+ # will be like self._request_register(request).
+ getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
+ work_request)
+
+ logging.info('No more requests from control plane')
+ logging.info('SDK Harness waiting for in-flight requests to complete')
# Wait until existing requests are processed.
self._progress_thread_pool.shutdown()
- self._default_work_thread_pool.shutdown()
+ self._process_thread_pool.shutdown()
# get_responses may be blocked on responses.get(), but we need to return
# control to its caller.
- responses.put(no_more_work)
+ self._responses.put(no_more_work)
self._data_channel_factory.close()
- state_handler.done()
+ # Stop all the workers and clean all the associated resources
+ for worker in self.workers.queue:
+ worker.state_handler.done()
logging.info('Done consuming work.')
+ def _execute(self, task, request):
+ try:
+ response = task()
+ except Exception as e: # pylint: disable=broad-except
+ traceback.print_exc(file=sys.stderr)
+ logging.error(
+ 'Error processing instruction %s. Original traceback is\n%s\n',
+ request.instruction_id,
+ traceback.format_exc(e),
+ exc_info=True)
+ response = beam_fn_api_pb2.InstructionResponse(
+ instruction_id=request.instruction_id, error=str(e))
+ self._responses.put(response)
+
+ def _request_register(self, request):
+
+ def task():
+ for process_bundle_descriptor in getattr(
+ request, request.WhichOneof('request')).process_bundle_descriptor:
+ self._fns[process_bundle_descriptor.id] = process_bundle_descriptor
+
+ return beam_fn_api_pb2.InstructionResponse(
+ instruction_id=request.instruction_id,
+ register=beam_fn_api_pb2.RegisterResponse())
+
+ self._execute(task, request)
+
+ def _request_process_bundle(self, request):
+
+ def task():
+ # Take a free worker; wait until one is available.
+ worker = self.workers.get()
+ # Get the first work item in the queue
+ work = self._process_bundle_queue.get()
+ # Record the instruction_id -> worker mapping for progress-reporting lookup
+ self._instruction_id_vs_worker[work.instruction_id] = worker
+ try:
+ self._execute(lambda: worker.do_instruction(work), work)
+ finally:
+ # Delete the instruction_id <-> worker mapping
+ self._instruction_id_vs_worker.pop(work.instruction_id, None)
+ # Put the worker back in the free worker pool
+ self.workers.put(worker)
+
+ # Create a task for each process_bundle request and schedule it
+ self._process_bundle_queue.put(request)
+ self._process_thread_pool.submit(task)
+
+ def _request_process_bundle_progress(self, request):
+ worker = self._instruction_id_vs_worker[request.instruction_id]
+
+ def task():
+ self._execute(lambda: worker.do_instruction(request), request)
+
+ self._progress_thread_pool.submit(task)
+
class SdkWorker(object):
- def __init__(self, state_handler, data_channel_factory):
- self.fns = {}
+ def __init__(self, state_handler, data_channel_factory, fns):
+ self.fns = fns
self.state_handler = state_handler
self.data_channel_factory = data_channel_factory
self.bundle_processors = {}
@@ -122,8 +179,8 @@ class SdkWorker(object):
request_type = request.WhichOneof('request')
if request_type:
# E.g. if register is set, this will call self.register(request.register)
- return getattr(self, request_type)(
- getattr(request, request_type), request.instruction_id)
+ return getattr(self, request_type)(getattr(request, request_type),
+ request.instruction_id)
else:
raise NotImplementedError
@@ -138,8 +195,7 @@ class SdkWorker(object):
self.bundle_processors[
instruction_id] = processor = bundle_processor.BundleProcessor(
self.fns[request.process_bundle_descriptor_reference],
- self.state_handler,
- self.data_channel_factory)
+ self.state_handler, self.data_channel_factory)
try:
processor.process_bundle(instruction_id)
finally:
@@ -176,6 +232,7 @@ class GrpcStateHandler(object):
if request is self._DONE or self._done:
break
yield request
+
responses = self._state_stub.State(request_iter())
def pull_responses():
@@ -187,6 +244,7 @@ class GrpcStateHandler(object):
except: # pylint: disable=bare-except
self._exc_info = sys.exc_info()
raise
+
reader = threading.Thread(target=pull_responses, name='read_state')
reader.daemon = True
reader.start()
@@ -239,6 +297,7 @@ class GrpcStateHandler(object):
class _Future(object):
"""A simple future object to implement blocking requests.
"""
+
def __init__(self):
self._event = threading.Event()
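The scheduling scheme introduced above (a queue.Queue of idle workers drained by tasks running on a bounded thread pool) can be seen in isolation in the sketch below. The worker objects and the do_instruction call are stand-ins for the commit's SdkWorker; only the queue/pool mechanics are the point:

    import Queue as queue  # Python 2, matching this module's imports

    from concurrent import futures

    def run_in_parallel(workers, requests):
        free_workers = queue.Queue()
        for worker in workers:
            free_workers.put(worker)
        pool = futures.ThreadPoolExecutor(max_workers=len(workers))

        def task(request):
            worker = free_workers.get()      # blocks until a worker is idle
            try:
                return worker.do_instruction(request)
            finally:
                free_workers.put(worker)     # return the worker to the pool

        results = [pool.submit(task, r) for r in requests]
        pool.shutdown(wait=True)
        return [r.result() for r in results]

Because free_workers.get() blocks, at most len(workers) bundles are processed concurrently even though every request is submitted to the pool immediately.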
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index 1db8b29..8671005 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -20,6 +20,7 @@ import BaseHTTPServer
import json
import logging
import os
+import re
import sys
import threading
import traceback
@@ -128,7 +129,9 @@ def main(unused_argv):
service_descriptor)
# TODO(robertwb): Support credentials.
assert not service_descriptor.oauth2_client_credentials_grant.url
- SdkHarness(service_descriptor.url).run()
+ SdkHarness(
+ control_address=service_descriptor.url,
+ worker_count=_get_worker_count(sdk_pipeline_options)).run()
logging.info('Python sdk harness exiting.')
except: # pylint: disable=broad-except
logging.exception('Python sdk harness failed: ')
@@ -138,6 +141,39 @@ def main(unused_argv):
fn_log_handler.close()
+def _get_worker_count(pipeline_options):
+ """Extract worker count from the pipeline_options.
+
+ This defines how many SdkWorkers will be started in this Python process,
+ and each SdkWorker will have its own thread to process data. The name of
+ the experimental parameter is 'worker_threads'.
+ Example usage on the command line:
+ --experiments worker_threads=1
+
+ Note: worker_threads is an experimental flag and might not be available in
+ future releases.
+
+ Returns:
+ an int containing the number of worker threads to use. Default is 1.
+ """
+ pipeline_options = pipeline_options.get(
+ 'options') if pipeline_options.has_key('options') else {}
+ experiments = pipeline_options.get(
+ 'experiments'
+ ) if pipeline_options and pipeline_options.has_key('experiments') else []
+
+ experiments = experiments if experiments else []
+
+ for experiment in experiments:
+ # There should be only one match, so return from inside the loop
+ if re.match(r'worker_threads=', experiment):
+ return int(
+ re.match(r'worker_threads=(?P<worker_threads>.*)',
+ experiment).group('worker_threads'))
+
+ return 1
+
+
def _load_main_session(semi_persistent_directory):
"""Loads a pickled main session from the path specified."""
if semi_persistent_directory:
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
index 9305c99..0e312f5 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import json
import logging
import unittest
@@ -34,11 +35,39 @@ class SdkWorkerMainTest(unittest.TestCase):
def wrapped_method_for_test():
lines = sdk_worker_main.StatusServer.get_thread_dump()
threaddump = '\n'.join(lines)
- self.assertRegexpMatches(threaddump, ".*wrapped_method_for_test.*")
+ self.assertRegexpMatches(threaddump, '.*wrapped_method_for_test.*')
wrapped_method_for_test()
+ def test_work_count_default_value(self):
+ self._check_worker_count('{}', 1)
-if __name__ == "__main__":
+ def test_work_count_custom_value(self):
+ self._check_worker_count(
+ '{"options": {"experiments":["worker_threads=1"]}}', 1)
+ self._check_worker_count(
+ '{"options": {"experiments":["worker_threads=4"]}}', 4)
+ self._check_worker_count(
+ '{"options": {"experiments":["worker_threads=12"]}}', 12)
+
+ def test_work_count_wrong_format(self):
+ self._check_worker_count(
+ '{"options": {"experiments":["worker_threads="]}}', exception=True)
+ self._check_worker_count(
+ '{"options": {"experiments":["worker_threads=a"]}}', exception=True)
+ self._check_worker_count(
+ '{"options": {"experiments":["worker_threads=1a"]}}', exception=True)
+
+ def _check_worker_count(self, pipeline_options, expected=0, exception=False):
+ if exception:
+ self.assertRaises(Exception, sdk_worker_main._get_worker_count,
+ json.loads(pipeline_options))
+ else:
+ self.assertEquals(
+ sdk_worker_main._get_worker_count(json.loads(pipeline_options)),
+ expected)
+
+
+if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
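The StatusServer exercised by this test is not itself shown in the diff; a plausible core for get_thread_dump, using only standard-library calls, might look like the sketch below. This is an assumption for illustration, not the commit's implementation:

    import sys
    import threading
    import traceback

    def get_thread_dump():
        # sys._current_frames() maps each live thread id to its topmost frame.
        frames = sys._current_frames()
        lines = []
        for thread in threading.enumerate():
            lines.append('--- Thread #%s name: %s ---' % (thread.ident, thread.name))
            frame = frames.get(thread.ident)
            if frame is not None:
                lines.extend(traceback.format_stack(frame))
        return lines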
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 2532341..c229d64 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
"""Tests for apache_beam.runners.worker.sdk_worker."""
from __future__ import absolute_import
@@ -62,29 +61,54 @@ class BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
class SdkWorkerTest(unittest.TestCase):
- def test_fn_registration(self):
- process_bundle_descriptors = [
+ def _get_process_bundles(self, prefix, size):
+ return [
beam_fn_api_pb2.ProcessBundleDescriptor(
- id=str(100+ix),
+ id=str(str(prefix) + "-" + str(ix)),
transforms={
- str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))})
- for ix in range(4)]
-
- test_controller = BeamFnControlServicer([beam_fn_api_pb2.InstructionRequest(
- register=beam_fn_api_pb2.RegisterRequest(
- process_bundle_descriptor=process_bundle_descriptors))])
-
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
- beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
- test_controller, server)
- test_port = server.add_insecure_port("[::]:0")
- server.start()
-
- harness = sdk_worker.SdkHarness("localhost:%s" % test_port)
- harness.run()
- self.assertEqual(
- harness.worker.fns,
- {item.id: item for item in process_bundle_descriptors})
+ str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))
+ }) for ix in range(size)
+ ]
+
+ def _check_fn_registration_multi_request(self, *args):
+ """Check the function registration calls to the sdk_harness.
+
+ Args:
+ *args: tuples of (request_count, process_bundles_per_request,
+ worker_count) describing the requests to exercise.
+ """
+ for (request_count, process_bundles_per_request, worker_count) in args:
+ requests = []
+ process_bundle_descriptors = []
+
+ for i in range(request_count):
+ pbd = self._get_process_bundles(i, process_bundles_per_request)
+ process_bundle_descriptors.extend(pbd)
+ requests.append(
+ beam_fn_api_pb2.InstructionRequest(
+ instruction_id=str(i),
+ register=beam_fn_api_pb2.RegisterRequest(
+ process_bundle_descriptor=process_bundle_descriptors)))
+
+ test_controller = BeamFnControlServicer(requests)
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
+ test_controller, server)
+ test_port = server.add_insecure_port("[::]:0")
+ server.start()
+
+ harness = sdk_worker.SdkHarness(
+ "localhost:%s" % test_port, worker_count=worker_count)
+ harness.run()
+
+ for worker in harness.workers.queue:
+ self.assertEqual(worker.fns,
+ {item.id: item
+ for item in process_bundle_descriptors})
+
+ def test_fn_registration(self):
+ self._check_fn_registration_multi_request((1, 4, 1), (4, 4, 1), (4, 4, 2))
if __name__ == "__main__":
--
To stop receiving notification emails like this one, please contact
commits@beam.apache.org.