You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by jo...@apache.org on 2023/02/17 14:19:20 UTC

[airflow] branch main updated: Get rid of state in Apache Beam provider hook (#29503)

This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ba27e7881 Get rid of state in Apache Beam provider hook (#29503)
7ba27e7881 is described below

commit 7ba27e78812b890f0c7642d78a986fe325ff61c4
Author: dannikay <48...@users.noreply.github.com>
AuthorDate: Fri Feb 17 06:19:11 2023 -0800

    Get rid of state in Apache Beam provider hook (#29503)
    
    * Get rid of state in Apache Beam provider hook
    
    * breeze static-checks --last-commit
    
    * Fix test_dataflow.py
    
    * Fix format
    
    * Add missing import
    
    * Fix data type.
    
    * Fix parameter type
    
    * Fix parameters
    
    * Update airflow/providers/apache/beam/hooks/beam.py comment
    
    Co-authored-by: Josh Fell <48...@users.noreply.github.com>
    
    ---------
    
    Co-authored-by: Xiaochu Liu <>
    Co-authored-by: Josh Fell <48...@users.noreply.github.com>
---
 airflow/providers/apache/beam/hooks/beam.py        | 146 +++++++++++----------
 tests/providers/apache/beam/hooks/test_beam.py     |  81 ++++++------
 .../providers/google/cloud/hooks/test_dataflow.py  |  29 ++--
 3 files changed, 134 insertions(+), 122 deletions(-)

diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
index c318d17363..9ea9e7b79b 100644
--- a/airflow/providers/apache/beam/hooks/beam.py
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -21,6 +21,7 @@ from __future__ import annotations
 import contextlib
 import copy
 import json
+import logging
 import os
 import select
 import shlex
@@ -35,7 +36,6 @@ from packaging.version import Version
 from airflow.exceptions import AirflowConfigException, AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.providers.google.go_module_utils import init_module, install_dependencies
-from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.python_virtualenv import prepare_virtualenv
 
 
@@ -81,81 +81,85 @@ def beam_options_to_args(options: dict) -> list[str]:
     return args
 
 
-class BeamCommandRunner(LoggingMixin):
+def process_fd(
+    proc,
+    fd,
+    log: logging.Logger,
+    process_line_callback: Callable[[str], None] | None = None,
+):
     """
-    Class responsible for running pipeline command in subprocess
+    Prints output to logs.
+
+    :param proc: subprocess.
+    :param fd: File descriptor.
+    :param process_line_callback: Optional callback which can be used to process
+        stdout and stderr to detect job id.
+    :param log: logger.
+    """
+    if fd not in (proc.stdout, proc.stderr):
+        raise Exception("No data in stderr or in stdout.")
+
+    fd_to_log = {proc.stderr: log.warning, proc.stdout: log.info}
+    func_log = fd_to_log[fd]
+
+    while True:
+        line = fd.readline().decode()
+        if not line:
+            return
+        if process_line_callback:
+            process_line_callback(line)
+        func_log(line.rstrip("\n"))
+
+
+def run_beam_command(
+    cmd: list[str],
+    log: logging.Logger,
+    process_line_callback: Callable[[str], None] | None = None,
+    working_directory: str | None = None,
+) -> None:
+    """
+    Function responsible for running pipeline command in subprocess.
 
     :param cmd: Parts of the command to be run in subprocess
     :param process_line_callback: Optional callback which can be used to process
         stdout and stderr to detect job id
     :param working_directory: Working directory
+    :param log: logger.
     """
-
-    def __init__(
-        self,
-        cmd: list[str],
-        process_line_callback: Callable[[str], None] | None = None,
-        working_directory: str | None = None,
-    ) -> None:
-        super().__init__()
-        self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd))
-        self.process_line_callback = process_line_callback
-        self.job_id: str | None = None
-
-        self._proc = subprocess.Popen(
-            cmd,
-            cwd=working_directory,
-            shell=False,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            close_fds=True,
-        )
-
-    def _process_fd(self, fd):
-        """
-        Prints output to logs.
-
-        :param fd: File descriptor.
-        """
-        if fd not in (self._proc.stdout, self._proc.stderr):
-            raise Exception("No data in stderr or in stdout.")
-
-        fd_to_log = {self._proc.stderr: self.log.warning, self._proc.stdout: self.log.info}
-        func_log = fd_to_log[fd]
-
-        while True:
-            line = fd.readline().decode()
-            if not line:
-                return
-            if self.process_line_callback:
-                self.process_line_callback(line)
-            func_log(line.rstrip("\n"))
-
-    def wait_for_done(self) -> None:
-        """Waits for Apache Beam pipeline to complete."""
-        self.log.info("Start waiting for Apache Beam process to complete.")
-        reads = [self._proc.stderr, self._proc.stdout]
-        while True:
-            # Wait for at least one available fd.
-            readable_fds, _, _ = select.select(reads, [], [], 5)
-            if readable_fds is None:
-                self.log.info("Waiting for Apache Beam process to complete.")
-                continue
-
-            for readable_fd in readable_fds:
-                self._process_fd(readable_fd)
-
-            if self._proc.poll() is not None:
-                break
-
-        # Corner case: check if more output was created between the last read and the process termination
-        for readable_fd in reads:
-            self._process_fd(readable_fd)
-
-        self.log.info("Process exited with return code: %s", self._proc.returncode)
-
-        if self._proc.returncode != 0:
-            raise AirflowException(f"Apache Beam process failed with return code {self._proc.returncode}")
+    log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd))
+
+    proc = subprocess.Popen(
+        cmd,
+        cwd=working_directory,
+        shell=False,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        close_fds=True,
+    )
+    # Waits for Apache Beam pipeline to complete.
+    log.info("Start waiting for Apache Beam process to complete.")
+    reads = [proc.stderr, proc.stdout]
+    while True:
+        # Wait for at least one available fd.
+        readable_fds, _, _ = select.select(reads, [], [], 5)
+        if readable_fds is None:
+            log.info("Waiting for Apache Beam process to complete.")
+            continue
+
+        for readable_fd in readable_fds:
+            process_fd(proc, readable_fd, log, process_line_callback)
+
+        if proc.poll() is not None:
+            break
+
+    # Corner case: check if more output was created between the last read and the process termination
+    for readable_fd in reads:
+        process_fd(proc, readable_fd, log, process_line_callback)
+
+    log.info("Process exited with return code: %s", proc.returncode)
+
+    if proc.returncode != 0:
+        raise AirflowException(f"Apache Beam process failed with return code {proc.returncode}")
 
 
 class BeamHook(BaseHook):
@@ -187,12 +191,12 @@ class BeamHook(BaseHook):
         ]
         if variables:
             cmd.extend(beam_options_to_args(variables))
-        cmd_runner = BeamCommandRunner(
+        run_beam_command(
             cmd=cmd,
             process_line_callback=process_line_callback,
             working_directory=working_directory,
+            log=self.log,
         )
-        cmd_runner.wait_for_done()
 
     def start_python_pipeline(
         self,
diff --git a/tests/providers/apache/beam/hooks/test_beam.py b/tests/providers/apache/beam/hooks/test_beam.py
index 80cf26687d..3e3e895cf9 100644
--- a/tests/providers/apache/beam/hooks/test_beam.py
+++ b/tests/providers/apache/beam/hooks/test_beam.py
@@ -21,12 +21,12 @@ import os
 import re
 import subprocess
 from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import ANY, MagicMock
 
 import pytest
 
 from airflow.exceptions import AirflowException
-from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner, BeamHook, beam_options_to_args
+from airflow.providers.apache.beam.hooks.beam import BeamHook, beam_options_to_args, run_beam_command
 
 PY_FILE = "apache_beam.examples.wordcount"
 JAR_FILE = "unitest.jar"
@@ -57,11 +57,10 @@ INFO: To cancel the job using the 'gcloud' tool, run:
 
 
 class TestBeamHook:
-    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    @mock.patch(BEAM_STRING.format("run_beam_command"))
     @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.39.0")
     def test_start_python_pipeline(self, mock_check_output, mock_runner):
         hook = BeamHook(runner=DEFAULT_RUNNER)
-        wait_for_done = mock_runner.return_value.wait_for_done
         process_line_callback = MagicMock()
 
         hook.start_python_pipeline(
@@ -80,9 +79,11 @@ class TestBeamHook:
             "--labels=foo=bar",
         ]
         mock_runner.assert_called_once_with(
-            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
+            cmd=expected_cmd,
+            process_line_callback=process_line_callback,
+            working_directory=None,
+            log=ANY,
         )
-        wait_for_done.assert_called_once_with()
 
     @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.35.0")
     def test_start_python_pipeline_unsupported_option(self, mock_check_output):
@@ -113,13 +114,12 @@ class TestBeamHook:
             pytest.param("python3.6", id="major.minor python version"),
         ],
     )
-    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    @mock.patch(BEAM_STRING.format("run_beam_command"))
     @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.39.0")
     def test_start_python_pipeline_with_custom_interpreter(
         self, mock_check_output, mock_runner, py_interpreter
     ):
         hook = BeamHook(runner=DEFAULT_RUNNER)
-        wait_for_done = mock_runner.return_value.wait_for_done
         process_line_callback = MagicMock()
 
         hook.start_python_pipeline(
@@ -139,9 +139,11 @@ class TestBeamHook:
             "--labels=foo=bar",
         ]
         mock_runner.assert_called_once_with(
-            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
+            cmd=expected_cmd,
+            process_line_callback=process_line_callback,
+            working_directory=None,
+            log=ANY,
         )
-        wait_for_done.assert_called_once_with()
 
     @pytest.mark.parametrize(
         "current_py_requirements, current_py_system_site_packages",
@@ -152,7 +154,7 @@ class TestBeamHook:
         ],
     )
     @mock.patch(BEAM_STRING.format("prepare_virtualenv"))
-    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    @mock.patch(BEAM_STRING.format("run_beam_command"))
     @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.39.0")
     def test_start_python_pipeline_with_non_empty_py_requirements_and_without_system_packages(
         self,
@@ -163,7 +165,6 @@ class TestBeamHook:
         current_py_system_site_packages,
     ):
         hook = BeamHook(runner=DEFAULT_RUNNER)
-        wait_for_done = mock_runner.return_value.wait_for_done
         mock_virtualenv.return_value = "/dummy_dir/bin/python"
         process_line_callback = MagicMock()
 
@@ -185,9 +186,11 @@ class TestBeamHook:
             "--labels=foo=bar",
         ]
         mock_runner.assert_called_once_with(
-            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
+            cmd=expected_cmd,
+            process_line_callback=process_line_callback,
+            working_directory=None,
+            log=ANY,
         )
-        wait_for_done.assert_called_once_with()
         mock_virtualenv.assert_called_once_with(
             venv_directory=mock.ANY,
             python_bin="python3",
@@ -195,7 +198,7 @@ class TestBeamHook:
             requirements=current_py_requirements,
         )
 
-    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    @mock.patch(BEAM_STRING.format("run_beam_command"))
     @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.39.0")
     def test_start_python_pipeline_with_empty_py_requirements_and_without_system_packages(
         self, mock_check_output, mock_runner
@@ -216,10 +219,9 @@ class TestBeamHook:
         mock_runner.assert_not_called()
         wait_for_done.assert_not_called()
 
-    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    @mock.patch(BEAM_STRING.format("run_beam_command"))
     def test_start_java_pipeline(self, mock_runner):
         hook = BeamHook(runner=DEFAULT_RUNNER)
-        wait_for_done = mock_runner.return_value.wait_for_done
         process_line_callback = MagicMock()
 
         hook.start_java_pipeline(
@@ -237,14 +239,12 @@ class TestBeamHook:
             '--labels={"foo":"bar"}',
         ]
         mock_runner.assert_called_once_with(
-            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
+            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None, log=ANY
         )
-        wait_for_done.assert_called_once_with()
 
-    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    @mock.patch(BEAM_STRING.format("run_beam_command"))
     def test_start_java_pipeline_with_job_class(self, mock_runner):
         hook = BeamHook(runner=DEFAULT_RUNNER)
-        wait_for_done = mock_runner.return_value.wait_for_done
         process_line_callback = MagicMock()
 
         hook.start_java_pipeline(
@@ -264,16 +264,17 @@ class TestBeamHook:
             '--labels={"foo":"bar"}',
         ]
         mock_runner.assert_called_once_with(
-            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
+            cmd=expected_cmd,
+            process_line_callback=process_line_callback,
+            working_directory=None,
+            log=ANY,
         )
-        wait_for_done.assert_called_once_with()
 
     @mock.patch(BEAM_STRING.format("shutil.which"))
-    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    @mock.patch(BEAM_STRING.format("run_beam_command"))
     def test_start_go_pipeline(self, mock_runner, mock_which):
         mock_which.return_value = "/some_path/to/go"
         hook = BeamHook(runner=DEFAULT_RUNNER)
-        wait_for_done = mock_runner.return_value.wait_for_done
         process_line_callback = MagicMock()
 
         hook.start_go_pipeline(
@@ -293,9 +294,11 @@ class TestBeamHook:
             '--labels={"foo":"bar"}',
         ]
         mock_runner.assert_called_once_with(
-            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=go_workspace
+            cmd=expected_cmd,
+            process_line_callback=process_line_callback,
+            working_directory=go_workspace,
+            log=ANY,
         )
-        wait_for_done.assert_called_once_with()
 
     @mock.patch(BEAM_STRING.format("shutil.which"))
     def test_start_go_pipeline_without_go_installed_raises(self, mock_which):
@@ -312,10 +315,9 @@ class TestBeamHook:
                 variables=copy.deepcopy(BEAM_VARIABLES_GO),
             )
 
-    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    @mock.patch(BEAM_STRING.format("run_beam_command"))
     def test_start_go_pipeline_with_binary(self, mock_runner):
         hook = BeamHook(runner=DEFAULT_RUNNER)
-        wait_for_done_method = mock_runner.return_value.wait_for_done
         process_line_callback = MagicMock()
 
         launcher_binary = "/path/to/launcher-main"
@@ -337,17 +339,19 @@ class TestBeamHook:
         ]
 
         mock_runner.assert_called_once_with(
-            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
+            cmd=expected_cmd,
+            process_line_callback=process_line_callback,
+            working_directory=None,
+            log=ANY,
         )
-        wait_for_done_method.assert_called_once_with()
 
 
 class TestBeamRunner:
-    @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log")
     @mock.patch("subprocess.Popen")
     @mock.patch("select.select")
-    def test_beam_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
+    def test_beam_wait_for_done_logging(self, mock_select, mock_popen):
         cmd = ["test", "cmd"]
+        mock_logging = MagicMock()
         mock_logging.info = MagicMock()
         mock_logging.warning = MagicMock()
         mock_proc = MagicMock()
@@ -365,13 +369,12 @@ class TestBeamRunner:
         mock_proc_poll.side_effect = [None, poll_resp_error]
         mock_proc.poll = mock_proc_poll
         mock_popen.return_value = mock_proc
-        beam = BeamCommandRunner(cmd)
-        mock_logging.info.assert_called_once_with("Running command: %s", " ".join(cmd))
-        mock_popen.assert_called_once_with(
-            cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True, cwd=None
-        )
         with pytest.raises(Exception):
-            beam.wait_for_done()
+            run_beam_command(cmd, None, None, mock_logging)
+            mock_logging.info.assert_called_once_with("Running command: %s", " ".join(cmd))
+            mock_popen.assert_called_once_with(
+                cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True, cwd=None
+            )
 
 
 class TestBeamOptionsToArgs:
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 1e48eb6a56..9e91a97d44 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -30,7 +30,7 @@ import pytest
 from google.cloud.dataflow_v1beta3 import GetJobRequest, JobView
 
 from airflow.exceptions import AirflowException
-from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner, BeamHook
+from airflow.providers.apache.beam.hooks.beam import BeamHook, run_beam_command
 from airflow.providers.google.cloud.hooks.dataflow import (
     DEFAULT_DATAFLOW_LOCATION,
     AsyncDataflowHook,
@@ -1849,9 +1849,12 @@ class TestDataflow:
             nonlocal found_job_id
             found_job_id = job_id
 
-        BeamCommandRunner(
-            cmd, process_line_callback=process_line_and_extract_dataflow_job_id_callback(callback)
-        ).wait_for_done()
+        mock_log = MagicMock()
+        run_beam_command(
+            cmd=cmd,
+            process_line_callback=process_line_and_extract_dataflow_job_id_callback(callback),
+            log=mock_log,
+        )
         assert found_job_id == TEST_JOB_ID
 
     def test_data_flow_missing_job_id(self):
@@ -1862,15 +1865,18 @@ class TestDataflow:
             nonlocal found_job_id
             found_job_id = job_id
 
-        BeamCommandRunner(
-            cmd, process_line_callback=process_line_and_extract_dataflow_job_id_callback(callback)
-        ).wait_for_done()
+        log = MagicMock()
+        run_beam_command(
+            cmd=cmd,
+            process_line_callback=process_line_and_extract_dataflow_job_id_callback(callback),
+            log=log,
+        )
         assert found_job_id is None
 
-    @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log")
     @mock.patch("subprocess.Popen")
     @mock.patch("select.select")
-    def test_dataflow_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
+    def test_dataflow_wait_for_done_logging(self, mock_select, mock_popen):
+        mock_logging = MagicMock()
         mock_logging.info = MagicMock()
         mock_logging.warning = MagicMock()
         mock_proc = MagicMock()
@@ -1888,10 +1894,9 @@ class TestDataflow:
         mock_proc_poll.side_effect = [None, poll_resp_error]
         mock_proc.poll = mock_proc_poll
         mock_popen.return_value = mock_proc
-        dataflow = BeamCommandRunner(["test", "cmd"])
-        mock_logging.info.assert_called_once_with("Running command: %s", "test cmd")
         with pytest.raises(Exception):
-            dataflow.wait_for_done()
+            run_beam_command(cmd=["test", "cmd"], log=mock_logging)
+            mock_logging.info.assert_called_once_with("Running command: %s", "test cmd")
 
 
 @pytest.fixture()