You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/09/22 15:42:36 UTC

[GitHub] [beam] alxmrs commented on a diff in pull request #22421: Initial DaskRunner for Beam

alxmrs commented on code in PR #22421:
URL: https://github.com/apache/beam/pull/22421#discussion_r977818314


##########
sdks/python/apache_beam/runners/dask/dask_runner.py:
##########
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""DaskRunner, executing remote jobs on Dask.distributed.
+
+The DaskRunner is a runner implementation that executes a graph of
+transformations across processes and workers via Dask distributed's
+scheduler.
+"""
+import dataclasses
+
+import argparse
+import typing as t
+
+from apache_beam import pvalue
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.pipeline import AppliedPTransform
+from apache_beam.pipeline import PipelineVisitor
+from apache_beam.runners.dask.overrides import dask_overrides
+from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
+from apache_beam.runners.dask.transform_evaluator import NoOp
+from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
+from apache_beam.runners.runner import PipelineResult
+from apache_beam.runners.runner import PipelineState
+from apache_beam.utils.interactive_utils import is_in_notebook
+
+
+class DaskOptions(PipelineOptions):
+
+  @staticmethod
+  def _parse_timeout(candidate):
+    try:
+      return int(candidate)
+    except (TypeError, ValueError):
+      import dask
+      return dask.config.no_default
+
+  @classmethod
+  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
+    parser.add_argument('--dask_client_address', dest='address', type=str, default=None,
+                        help='Address of a dask Scheduler server. Will '
+                             'default to a `dask.LocalCluster()`.')
+    parser.add_argument('--dask_connection_timeout', dest='timeout',
+                        type=DaskOptions._parse_timeout,
+                        help='Timeout duration for initial connection to the '
+                             'scheduler.')
+    parser.add_argument('--dask_scheduler_file', type=str, default=None,
+                        help='Path to a file with scheduler information if '
+                             'available.')
+    # TODO(alxr): Add options for security.
+    parser.add_argument('--dask_client_name', dest='name', type=str,
+                        default=None,
+                        help='Gives the client a name that will be included '
+                             'in logs generated on the scheduler for matters '
+                             'relating to this client.')
+    parser.add_argument('--dask_connection_limit', dest='connection_limit',
+                        type=int, default=512,
+                        help='The number of open comms to maintain at once in '
+                             'the connection pool.')
+
+
+@dataclasses.dataclass
+class DaskRunnerResult(PipelineResult):
+  from dask import distributed
+
+  client: distributed.Client
+  futures: t.Sequence[distributed.Future]
+
+  def __post_init__(self):
+    super().__init__(PipelineState.RUNNING)
+
+  def wait_until_finish(self, duration=None) -> PipelineState:
+    try:
+      if duration is not None:
+        # Convert milliseconds to seconds
+        duration /= 1000
+      self.client.wait_for_workers(timeout=duration)
+      self.client.gather(self.futures, errors='raise', asynchronous=True)
+      self._state = PipelineState.DONE
+    except:  # pylint: disable=broad-except
+      self._state = PipelineState.FAILED
+      raise
+    # finally:
+    #   self.client.close(timeout=duration)
+    return self._state
+
+  def cancel(self) -> PipelineState:
+    self._state = PipelineState.CANCELLING
+    self.client.cancel(self.futures)
+    self._state = PipelineState.CANCELLED
+    return self._state
+
+  def metrics(self):
+    # TODO(alxr): Collect and return metrics...
+    raise NotImplementedError('collecting metrics will come later!')
+
+
+class DaskRunner(BundleBasedDirectRunner):
+  """Executes a pipeline on a Dask distributed client."""
+
+  @staticmethod
+  def to_dask_bag_visitor() -> PipelineVisitor:
+    """Build a visitor that translates each Beam transform into a dask Bag.
+
+    Returns:
+      A `PipelineVisitor` whose `bags` dict maps every visited
+      `AppliedPTransform` to the dask Bag produced by its evaluator.
+    """
+    from dask import bag as db
+
+    @dataclasses.dataclass
+    class DaskBagVisitor(PipelineVisitor):
+      # Accumulates the Bag computed for each transform node, keyed by the
+      # node itself so downstream transforms can find their inputs.
+      bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
+        default_factory=dict)
+
+      def visit_transform(self, transform_node: AppliedPTransform) -> None:
+        # Pick the evaluator registered for this transform type; unknown
+        # transforms fall back to NoOp.
+        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
+        op = op_class(transform_node)
+
+        inputs = list(transform_node.inputs)
+        if inputs:
+          bag_inputs = []
+          for input_value in inputs:
+            # A PBegin marks a pipeline source: no upstream bag exists.
+            if isinstance(input_value, pvalue.PBegin):
+              bag_inputs.append(None)
+
+            # NOTE(review): this also executes for PBegin inputs — it only
+            # works if a PBegin's `producer` is never a key in `self.bags`.
+            # Confirm whether this was meant to be an `else` branch.
+            prev_op = input_value.producer
+            if prev_op in self.bags:
+              bag_inputs.append(self.bags[prev_op])
+
+          # A single input is unwrapped so evaluators receive a bare Bag
+          # rather than a one-element list.
+          if len(bag_inputs) == 1:
+            self.bags[transform_node] = op.apply(bag_inputs[0])
+          else:
+            self.bags[transform_node] = op.apply(bag_inputs)
+
+        else:
+          # Transforms with no inputs (roots) are applied to None.
+          self.bags[transform_node] = op.apply(None)
+
+    return DaskBagVisitor()
+
+  @staticmethod
+  def is_fnapi_compatible():
+    """This runner does not use the Beam FnAPI execution path."""
+    return False
+
+  def run_pipeline(self, pipeline, options):
+    # TODO(alxr): Create interactive notebook support.
+    if is_in_notebook():
+      raise NotImplementedError('interactive support will come later!')
+
+    try:
+      import dask.distributed as ddist
+    except ImportError:
+      raise ImportError(
+        'DaskRunner is not available. Please install apache_beam[dask].')
+
+    dask_options = options.view_as(DaskOptions).get_all_options(
+      drop_default=True)
+    client = ddist.Client(**dask_options)

Review Comment:
   This is my first runner – @pabloem can probably weigh in better than I can with regard to your question. However, what makes sense to me is that each Beam runner should clean up its environment between each run, including in tests.
   
   This probably should happen in the `DaskRunnerResult` object. Do you have any recommendations on the best way to clean up dask (distributed)? 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org