You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by al...@apache.org on 2017/12/15 21:52:17 UTC

[beam] branch master updated: [BEAM-3189] Sdk worker multithreading (#4134)

This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 049070b  [BEAM-3189] Sdk worker multithreading (#4134)
049070b is described below

commit 049070be4b4e6f1fc16092b7308ed4054e6ef76a
Author: Ankur <an...@users.noreply.github.com>
AuthorDate: Fri Dec 15 13:52:14 2017 -0800

    [BEAM-3189] Sdk worker multithreading (#4134)
    
    * Adding debug server to sdk worker to get thread dumps
    
    * Adding multi threaded function registration test
    
    * Wrapping SdkWorker to associate more state with it.
    
    * Making multiple workers work in parallel
    
    * Adding experimental option for worker_threads
    
    * Fixing experimental flag for worker_threads
    
    * Fix Dataplane Concurrency bug
    
    * Making method prefix constant.
    
    * Simplifying code and selecting worker based on availability
    
    * Fixing lint errors
    
    * Moving worker count extraction to sdk_worker_main
    
    * Adding test Cases and comments
    
    * Fixing lint error
    
    * Making worker_count a required argument (no default).
---
 .../runners/portability/fn_api_runner.py           |   9 +-
 .../runners/portability/fn_api_runner_test.py      |  13 +-
 .../apache_beam/runners/worker/data_plane.py       |  23 +--
 .../apache_beam/runners/worker/sdk_worker.py       | 167 ++++++++++++++-------
 .../apache_beam/runners/worker/sdk_worker_main.py  |  38 ++++-
 .../runners/worker/sdk_worker_main_test.py         |  33 +++-
 .../apache_beam/runners/worker/sdk_worker_test.py  |  68 ++++++---
 7 files changed, 258 insertions(+), 93 deletions(-)

diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 9b143c6..a7a73f4 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -855,7 +855,7 @@ class FnApiRunner(runner.PipelineRunner):
       self.data_plane_handler = data_plane.InMemoryDataChannel()
       self.worker = sdk_worker.SdkWorker(
           self.state_handler, data_plane.InMemoryDataChannelFactory(
-              self.data_plane_handler.inverse()))
+              self.data_plane_handler.inverse()), {})
 
     def push(self, request):
       logging.debug('CONTROL REQUEST %s', request)
@@ -907,8 +907,11 @@ class FnApiRunner(runner.PipelineRunner):
       self.data_server.start()
       self.control_server.start()
 
-      self.worker = (self.sdk_harness_factory or sdk_worker.SdkHarness)(
-          'localhost:%s' % self.control_port)
+      self.worker = self.sdk_harness_factory(
+          'localhost:%s' % self.control_port
+      ) if self.sdk_harness_factory else sdk_worker.SdkHarness(
+          'localhost:%s' % self.control_port, worker_count=1)
+
       self.worker_thread = threading.Thread(
           name='run_worker', target=self.worker.run)
       logging.info('starting worker')
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 249eece..eb297ab 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import functools
 import logging
 import time
 import unittest
@@ -22,6 +22,7 @@ import unittest
 import apache_beam as beam
 from apache_beam.runners.portability import fn_api_runner
 from apache_beam.runners.portability import maptask_executor_runner_test
+from apache_beam.runners.worker import sdk_worker
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms import window
@@ -155,6 +156,16 @@ class FnApiRunnerTestWithGrpc(FnApiRunnerTest):
         runner=fn_api_runner.FnApiRunner(use_grpc=True))
 
 
+class FnApiRunnerTestWithGrpcMultiThreaded(FnApiRunnerTest):
+
+  def create_pipeline(self):
+    return beam.Pipeline(
+        runner=fn_api_runner.FnApiRunner(
+            use_grpc=True,
+            sdk_harness_factory=functools.partial(
+                sdk_worker.SdkHarness, worker_count=2)))
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index f2a3751..b3d4854 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -277,20 +277,23 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
 
   def __init__(self):
     self._data_channel_cache = {}
+    self._lock = threading.Lock()
 
   def create_data_channel(self, remote_grpc_port):
     url = remote_grpc_port.api_service_descriptor.url
     if url not in self._data_channel_cache:
-      logging.info('Creating channel for %s', url)
-      grpc_channel = grpc.insecure_channel(
-          url,
-          # Options to have no limits (-1) on the size of the messages
-          # received or sent over the data plane. The actual buffer size is
-          # controlled in a layer above.
-          options=[("grpc.max_receive_message_length", -1),
-                   ("grpc.max_send_message_length", -1)])
-      self._data_channel_cache[url] = GrpcClientDataChannel(
-          beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel))
+      with self._lock:
+        if url not in self._data_channel_cache:
+          logging.info('Creating channel for %s', url)
+          grpc_channel = grpc.insecure_channel(
+              url,
+              # Options to have no limits (-1) on the size of the messages
+              # received or sent over the data plane. The actual buffer size is
+              # controlled in a layer above.
+              options=[("grpc.max_receive_message_length", -1),
+                       ("grpc.max_send_message_length", -1)])
+          self._data_channel_cache[url] = GrpcClientDataChannel(
+              beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel))
     return self._data_channel_cache[url]
 
   def close(self):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index fec844e..980357e 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -14,14 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 """SDK harness for executing Python Fns via the Fn API."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import functools
 import logging
 import Queue as queue
 import sys
@@ -38,27 +36,54 @@ from apache_beam.runners.worker import data_plane
 
 
 class SdkHarness(object):
+  REQUEST_METHOD_PREFIX = '_request_'
 
-  def __init__(self, control_address):
+  def __init__(self, control_address, worker_count):
+    self._worker_count = worker_count
+    self._worker_index = 0
     self._control_channel = grpc.insecure_channel(control_address)
     self._data_channel_factory = data_plane.GrpcClientDataChannelFactory()
-    # TODO: Ensure thread safety to run with more than 1 thread.
-    self._default_work_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
+    self.workers = queue.Queue()
+    # One thread is enough for generating progress reports.
+    # Assumption:
+    # Progress report generation should not do IO or wait on other resources.
+    #  Without waiting, having multiple threads will not improve performance
+    #  and will only add complexity.
     self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
+    self._process_thread_pool = futures.ThreadPoolExecutor(
+        max_workers=self._worker_count)
+    self._instruction_id_vs_worker = {}
+    self._fns = {}
+    self._responses = queue.Queue()
+    self._process_bundle_queue = queue.Queue()
+    logging.info('Initializing SDKHarness with %s workers.', self._worker_count)
 
   def run(self):
     control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel)
-    state_stub = beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel)
-    state_handler = GrpcStateHandler(state_stub)
-    state_handler.start()
-    self.worker = SdkWorker(state_handler, self._data_channel_factory)
-
-    responses = queue.Queue()
     no_more_work = object()
 
+    # Create workers
+    for _ in range(self._worker_count):
+      state_handler = GrpcStateHandler(
+          beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel))
+      state_handler.start()
+      # SdkHarness manages function registration and shares self._fns with all
+      # the workers. This is needed because function registration (register)
+      # and execution (process_bundle) are sent in different requests, and we
+      # do not know which worker is going to process a bundle
+      # for a function until we get the process_bundle request. Moreover, the
+      # same function is reused by different process_bundle calls and
+      # potentially executed by different workers. Hence we need a
+      # centralized function registry shared among all the workers.
+      self.workers.put(
+          SdkWorker(
+              state_handler=state_handler,
+              data_channel_factory=self._data_channel_factory,
+              fns=self._fns))
+
     def get_responses():
       while True:
-        response = responses.get()
+        response = self._responses.get()
         if response is no_more_work:
           return
         yield response
@@ -66,54 +91,86 @@ class SdkHarness(object):
     for work_request in control_stub.Control(get_responses()):
       logging.info('Got work %s', work_request.instruction_id)
       request_type = work_request.WhichOneof('request')
-      # WhichOneOf returns the name of the set field as a single string
-      if request_type in ['process_bundle_progress']:
-        thread_pool = self._progress_thread_pool
-      else:
-        thread_pool = self._default_work_thread_pool
-
-      # Need this wrapper to capture the original stack trace.
-      def do_instruction(request):
-        try:
-          return self.worker.do_instruction(request)
-        except Exception as e:  # pylint: disable=broad-except
-          traceback_str = traceback.format_exc(e)
-          raise Exception("Error processing request. Original traceback "
-                          "is\n%s\n" % traceback_str)
-
-      def handle_response(request, response_future):
-        try:
-          response = response_future.result()
-        except Exception as e:  # pylint: disable=broad-except
-          logging.error(
-              'Error processing instruction %s',
-              request.instruction_id,
-              exc_info=True)
-          response = beam_fn_api_pb2.InstructionResponse(
-              instruction_id=request.instruction_id,
-              error=str(e))
-        responses.put(response)
-
-      thread_pool.submit(do_instruction, work_request).add_done_callback(
-          functools.partial(handle_response, work_request))
-
-    logging.info("No more requests from control plane")
-    logging.info("SDK Harness waiting for in-flight requests to complete")
+      # Namespace the request method with '_request_'. The called method
+      # will be like self._request_register(request)
+      getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
+          work_request)
+
+    logging.info('No more requests from control plane')
+    logging.info('SDK Harness waiting for in-flight requests to complete')
     # Wait until existing requests are processed.
     self._progress_thread_pool.shutdown()
-    self._default_work_thread_pool.shutdown()
+    self._process_thread_pool.shutdown()
     # get_responses may be blocked on responses.get(), but we need to return
     # control to its caller.
-    responses.put(no_more_work)
+    self._responses.put(no_more_work)
     self._data_channel_factory.close()
-    state_handler.done()
+    # Stop all the workers and clean all the associated resources
+    for worker in self.workers.queue:
+      worker.state_handler.done()
     logging.info('Done consuming work.')
 
+  def _execute(self, task, request):
+    try:
+      response = task()
+    except Exception as e:  # pylint: disable=broad-except
+      traceback.print_exc(file=sys.stderr)
+      logging.error(
+          'Error processing instruction %s. Original traceback is\n%s\n',
+          request.instruction_id,
+          traceback.format_exc(e),
+          exc_info=True)
+      response = beam_fn_api_pb2.InstructionResponse(
+          instruction_id=request.instruction_id, error=str(e))
+    self._responses.put(response)
+
+  def _request_register(self, request):
+
+    def task():
+      for process_bundle_descriptor in getattr(
+          request, request.WhichOneof('request')).process_bundle_descriptor:
+        self._fns[process_bundle_descriptor.id] = process_bundle_descriptor
+
+      return beam_fn_api_pb2.InstructionResponse(
+          instruction_id=request.instruction_id,
+          register=beam_fn_api_pb2.RegisterResponse())
+
+    self._execute(task, request)
+
+  def _request_process_bundle(self, request):
+
+    def task():
+      # Take a free worker, waiting until one becomes available.
+      worker = self.workers.get()
+      # Get the first work item in the queue
+      work = self._process_bundle_queue.get()
+      # Add the instruction_id -> worker mapping for progress report lookup.
+      self._instruction_id_vs_worker[work.instruction_id] = worker
+      try:
+        self._execute(lambda: worker.do_instruction(work), work)
+      finally:
+        # Delete the instruction_id <-> worker mapping
+        self._instruction_id_vs_worker.pop(work.instruction_id, None)
+        # Put the worker back in the free worker pool
+        self.workers.put(worker)
+
+    # Create a task for each process_bundle request and schedule it
+    self._process_bundle_queue.put(request)
+    self._process_thread_pool.submit(task)
+
+  def _request_process_bundle_progress(self, request):
+    worker = self._instruction_id_vs_worker[request.instruction_id]
+
+    def task():
+      self._execute(lambda: worker.do_instruction(request), request)
+
+    self._progress_thread_pool.submit(task)
+
 
 class SdkWorker(object):
 
-  def __init__(self, state_handler, data_channel_factory):
-    self.fns = {}
+  def __init__(self, state_handler, data_channel_factory, fns):
+    self.fns = fns
     self.state_handler = state_handler
     self.data_channel_factory = data_channel_factory
     self.bundle_processors = {}
@@ -122,8 +179,8 @@ class SdkWorker(object):
     request_type = request.WhichOneof('request')
     if request_type:
       # E.g. if register is set, this will call self.register(request.register))
-      return getattr(self, request_type)(
-          getattr(request, request_type), request.instruction_id)
+      return getattr(self, request_type)(getattr(request, request_type),
+                                         request.instruction_id)
     else:
       raise NotImplementedError
 
@@ -138,8 +195,7 @@ class SdkWorker(object):
     self.bundle_processors[
         instruction_id] = processor = bundle_processor.BundleProcessor(
             self.fns[request.process_bundle_descriptor_reference],
-            self.state_handler,
-            self.data_channel_factory)
+            self.state_handler, self.data_channel_factory)
     try:
       processor.process_bundle(instruction_id)
     finally:
@@ -176,6 +232,7 @@ class GrpcStateHandler(object):
         if request is self._DONE or self._done:
           break
         yield request
+
     responses = self._state_stub.State(request_iter())
 
     def pull_responses():
@@ -187,6 +244,7 @@ class GrpcStateHandler(object):
       except:  # pylint: disable=bare-except
         self._exc_info = sys.exc_info()
         raise
+
     reader = threading.Thread(target=pull_responses, name='read_state')
     reader.daemon = True
     reader.start()
@@ -239,6 +297,7 @@ class GrpcStateHandler(object):
 class _Future(object):
   """A simple future object to implement blocking requests.
   """
+
   def __init__(self):
     self._event = threading.Event()
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index 1db8b29..8671005 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -20,6 +20,7 @@ import BaseHTTPServer
 import json
 import logging
 import os
+import re
 import sys
 import threading
 import traceback
@@ -128,7 +129,9 @@ def main(unused_argv):
                       service_descriptor)
     # TODO(robertwb): Support credentials.
     assert not service_descriptor.oauth2_client_credentials_grant.url
-    SdkHarness(service_descriptor.url).run()
+    SdkHarness(
+        control_address=service_descriptor.url,
+        worker_count=_get_worker_count(sdk_pipeline_options)).run()
     logging.info('Python sdk harness exiting.')
   except:  # pylint: disable=broad-except
     logging.exception('Python sdk harness failed: ')
@@ -138,6 +141,39 @@ def main(unused_argv):
       fn_log_handler.close()
 
 
+def _get_worker_count(pipeline_options):
+  """Extract worker count from the pipeline_options.
+
+  This defines how many SdkWorkers will be started in this Python process,
+  each with its own thread to process data. The name of the
+  experimental parameter is 'worker_threads'.
+  Example usage on the command line:
+    --experiments worker_threads=1
+
+  Note: worker_threads is an experimental flag and might not be available in
+  future releases.
+
+  Returns:
+    The number of worker threads to use, as an int. Defaults to 1.
+  """
+  pipeline_options = pipeline_options.get(
+      'options') if pipeline_options.has_key('options') else {}
+  experiments = pipeline_options.get(
+      'experiments'
+  ) if pipeline_options and pipeline_options.has_key('experiments') else []
+
+  experiments = experiments if experiments else []
+
+  for experiment in experiments:
+    # There should be only one match, so return from within the loop.
+    if re.match(r'worker_threads=', experiment):
+      return int(
+          re.match(r'worker_threads=(?P<worker_threads>.*)',
+                   experiment).group('worker_threads'))
+
+  return 1
+
+
 def _load_main_session(semi_persistent_directory):
   """Loads a pickled main session from the path specified."""
   if semi_persistent_directory:
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
index 9305c99..0e312f5 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import json
 import logging
 import unittest
 
@@ -34,11 +35,39 @@ class SdkWorkerMainTest(unittest.TestCase):
     def wrapped_method_for_test():
       lines = sdk_worker_main.StatusServer.get_thread_dump()
       threaddump = '\n'.join(lines)
-      self.assertRegexpMatches(threaddump, ".*wrapped_method_for_test.*")
+      self.assertRegexpMatches(threaddump, '.*wrapped_method_for_test.*')
 
     wrapped_method_for_test()
 
+  def test_work_count_default_value(self):
+    self._check_worker_count('{}', 1)
 
-if __name__ == "__main__":
+  def test_work_count_custom_value(self):
+    self._check_worker_count(
+        '{"options": {"experiments":["worker_threads=1"]}}', 1)
+    self._check_worker_count(
+        '{"options": {"experiments":["worker_threads=4"]}}', 4)
+    self._check_worker_count(
+        '{"options": {"experiments":["worker_threads=12"]}}', 12)
+
+  def test_work_count_wrong_format(self):
+    self._check_worker_count(
+        '{"options": {"experiments":["worker_threads="]}}', exception=True)
+    self._check_worker_count(
+        '{"options": {"experiments":["worker_threads=a"]}}', exception=True)
+    self._check_worker_count(
+        '{"options": {"experiments":["worker_threads=1a"]}}', exception=True)
+
+  def _check_worker_count(self, pipeline_options, expected=0, exception=False):
+    if exception:
+      self.assertRaises(Exception, sdk_worker_main._get_worker_count,
+                        json.loads(pipeline_options))
+    else:
+      self.assertEquals(
+          sdk_worker_main._get_worker_count(json.loads(pipeline_options)),
+          expected)
+
+
+if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 2532341..c229d64 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 """Tests for apache_beam.runners.worker.sdk_worker."""
 
 from __future__ import absolute_import
@@ -62,29 +61,54 @@ class BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
 
 class SdkWorkerTest(unittest.TestCase):
 
-  def test_fn_registration(self):
-    process_bundle_descriptors = [
+  def _get_process_bundles(self, prefix, size):
+    return [
         beam_fn_api_pb2.ProcessBundleDescriptor(
-            id=str(100+ix),
+            id=str(str(prefix) + "-" + str(ix)),
             transforms={
-                str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))})
-        for ix in range(4)]
-
-    test_controller = BeamFnControlServicer([beam_fn_api_pb2.InstructionRequest(
-        register=beam_fn_api_pb2.RegisterRequest(
-            process_bundle_descriptor=process_bundle_descriptors))])
-
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
-    beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
-        test_controller, server)
-    test_port = server.add_insecure_port("[::]:0")
-    server.start()
-
-    harness = sdk_worker.SdkHarness("localhost:%s" % test_port)
-    harness.run()
-    self.assertEqual(
-        harness.worker.fns,
-        {item.id: item for item in process_bundle_descriptors})
+                str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))
+            }) for ix in range(size)
+    ]
+
+  def _check_fn_registration_multi_request(self, *args):
+    """Check the function registration calls to the sdk_harness.
+
+    Args:
+     *args: tuples of (request_count, process_bundles_per_request,
+     worker_count); each tuple is run against a fresh harness.
+    """
+    for (request_count, process_bundles_per_request, worker_count) in args:
+      requests = []
+      process_bundle_descriptors = []
+
+      for i in range(request_count):
+        pbd = self._get_process_bundles(i, process_bundles_per_request)
+        process_bundle_descriptors.extend(pbd)
+        requests.append(
+            beam_fn_api_pb2.InstructionRequest(
+                instruction_id=str(i),
+                register=beam_fn_api_pb2.RegisterRequest(
+                    process_bundle_descriptor=process_bundle_descriptors)))
+
+      test_controller = BeamFnControlServicer(requests)
+
+      server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+      beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
+          test_controller, server)
+      test_port = server.add_insecure_port("[::]:0")
+      server.start()
+
+      harness = sdk_worker.SdkHarness(
+          "localhost:%s" % test_port, worker_count=worker_count)
+      harness.run()
+
+      for worker in harness.workers.queue:
+        self.assertEqual(worker.fns,
+                         {item.id: item
+                          for item in process_bundle_descriptors})
+
+  def test_fn_registration(self):
+    self._check_fn_registration_multi_request((1, 4, 1), (4, 4, 1), (4, 4, 2))
 
 
 if __name__ == "__main__":

-- 
To stop receiving notification emails like this one, please contact
['"commits@beam.apache.org" <co...@beam.apache.org>'].