Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/31 05:15:23 UTC

[airflow] branch main updated: DataflowStopJobOperator Operator (#27033)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 50d217a129 DataflowStopJobOperator Operator (#27033)
50d217a129 is described below

commit 50d217a1290f891be5d6be743b00b552fc10da20
Author: Rafał Grudowski <40...@users.noreply.github.com>
AuthorDate: Mon Oct 31 06:15:15 2022 +0100

    DataflowStopJobOperator Operator (#27033)
---
 .../google/cloud/example_dags/example_dataflow.py  | 25 ++++++
 airflow/providers/google/cloud/hooks/dataflow.py   | 25 +++---
 .../providers/google/cloud/operators/dataflow.py   | 93 ++++++++++++++++++++++
 .../operators/cloud/dataflow.rst                   | 18 +++++
 .../providers/google/cloud/hooks/test_dataflow.py  | 16 +---
 .../google/cloud/operators/test_dataflow.py        | 54 +++++++++++++
 6 files changed, 209 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py
index 1364de183e..b7b0c017dc 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataflow.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py
@@ -34,6 +34,7 @@ from airflow.providers.apache.beam.operators.beam import (
 from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning,
+    DataflowStopJobOperator,
     DataflowTemplatedJobStartOperator,
 )
 from airflow.providers.google.cloud.sensors.dataflow import (
@@ -261,3 +262,27 @@ with models.DAG(
         location="europe-west3",
     )
     # [END howto_operator_start_template_job]
+
+with models.DAG(
+    "example_gcp_stop_dataflow_job",
+    default_args=default_args,
+    start_date=START_DATE,
+    catchup=False,
+    tags=["example"],
+) as dag_template:
+    # [START howto_operator_stop_dataflow_job]
+    stop_dataflow_job = DataflowStopJobOperator(
+        task_id="stop-dataflow-job",
+        location="europe-west3",
+        job_name_prefix="start-template-job",
+    )
+    # [END howto_operator_stop_dataflow_job]
+    start_template_job = DataflowTemplatedJobStartOperator(
+        task_id="start-template-job",
+        template="gs://dataflow-templates/latest/Word_Count",
+        parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
+        location="europe-west3",
+        append_job_name=False,
+    )
+
+    stop_dataflow_job >> start_template_job
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 490024b693..b9dbf9478c 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -237,6 +237,8 @@ class _DataflowJobsController(LoggingMixin):
         """
         if not self._multiple_jobs and self._job_id:
             return [self.fetch_job_by_id(self._job_id)]
+        elif self._jobs:
+            return [self.fetch_job_by_id(job["id"]) for job in self._jobs]
         elif self._job_name:
             jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower())
             if len(jobs) == 1:
@@ -445,11 +447,11 @@ class _DataflowJobsController(LoggingMixin):
             job_states = {job["currentState"] for job in self._jobs}
             if not job_states.difference(expected_states):
                 return
-            unexpected_failed_end_states = expected_states - DataflowJobStatus.FAILED_END_STATES
+            unexpected_failed_end_states = DataflowJobStatus.FAILED_END_STATES - expected_states
             if unexpected_failed_end_states.intersection(job_states):
-                unexpected_failed_jobs = {
+                unexpected_failed_jobs = [
                     job for job in self._jobs if job["currentState"] in unexpected_failed_end_states
-                }
+                ]
                 raise AirflowException(
                     "Jobs failed: "
                     + ", ".join(
@@ -461,18 +463,19 @@ class _DataflowJobsController(LoggingMixin):
 
     def cancel(self) -> None:
         """Cancels or drains current job"""
-        jobs = self.get_jobs()
-        job_ids = [job["id"] for job in jobs if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES]
+        self._jobs = [
+            job for job in self.get_jobs() if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES
+        ]
+        job_ids = [job["id"] for job in self._jobs]
         if job_ids:
-            batch = self._dataflow.new_batch_http_request()
             self.log.info("Canceling jobs: %s", ", ".join(job_ids))
-            for job in jobs:
+            for job in self._jobs:
                 requested_state = (
                     DataflowJobStatus.JOB_STATE_DRAINED
                     if self.drain_pipeline and job["type"] == DataflowJobType.JOB_TYPE_STREAMING
                     else DataflowJobStatus.JOB_STATE_CANCELLED
                 )
-                batch.add(
+                request = (
                     self._dataflow.projects()
                     .locations()
                     .jobs()
@@ -483,14 +486,16 @@ class _DataflowJobsController(LoggingMixin):
                         body={"requestedState": requested_state},
                     )
                 )
-            batch.execute()
+                request.execute(num_retries=self._num_retries)
             if self._cancel_timeout and isinstance(self._cancel_timeout, int):
                 timeout_error_message = (
                     f"Canceling jobs failed due to timeout ({self._cancel_timeout}s): {', '.join(job_ids)}"
                 )
                 tm = timeout(seconds=self._cancel_timeout, error_message=timeout_error_message)
                 with tm:
-                    self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
+                    self._wait_for_states(
+                        {DataflowJobStatus.JOB_STATE_CANCELLED, DataflowJobStatus.JOB_STATE_DRAINED}
+                    )
         else:
             self.log.info("No jobs to cancel")
 
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 6d0a0412ae..7cfab6bbff 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -1137,3 +1137,96 @@ class DataflowCreatePythonJobOperator(BaseOperator):
             self.dataflow_hook.cancel_job(
                 job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id
             )
+
+
+class DataflowStopJobOperator(BaseOperator):
+    """
+    Stops the job with the specified name prefix or Job ID.
+    All jobs with the provided name prefix will be stopped.
+    Streaming jobs are drained by default.
+
+    Parameters ``job_name_prefix`` and ``job_id`` are mutually exclusive.
+
+    .. seealso::
+        For more details on stopping a pipeline see:
+        https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DataflowStopJobOperator`
+
+    :param job_name_prefix: Name prefix specifying which jobs are to be stopped.
+    :param job_id: Job ID specifying which job is to be stopped.
+    :param project_id: Optional, the Google Cloud project ID in which the job is running.
+        If set to None or missing, the default project_id from the Google Cloud connection is used.
+    :param location: Optional, Job location. If set to None or missing, "us-central1" will be used.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :param poll_sleep: The time in seconds to sleep between polling Google
+        Cloud Platform for the dataflow job status to confirm it's stopped.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param drain_pipeline: Optional, set to False to stop a streaming job by canceling it
+        instead of draining it. See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+    :param stop_timeout: Wait time in seconds for a successful job cancel/drain.
+    """
+
+    def __init__(
+        self,
+        job_name_prefix: str | None = None,
+        job_id: str | None = None,
+        project_id: str | None = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: str | None = None,
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+        stop_timeout: int | None = 10 * 60,
+        drain_pipeline: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poll_sleep = poll_sleep
+        self.stop_timeout = stop_timeout
+        self.job_name = job_name_prefix
+        self.job_id = job_id
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        self.hook: DataflowHook | None = None
+        self.drain_pipeline = drain_pipeline
+
+    def execute(self, context: Context) -> None:
+        self.dataflow_hook = DataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+            cancel_timeout=self.stop_timeout,
+            drain_pipeline=self.drain_pipeline,
+        )
+        if self.job_id or self.dataflow_hook.is_job_dataflow_running(
+            name=self.job_name,
+            project_id=self.project_id,
+            location=self.location,
+        ):
+            self.dataflow_hook.cancel_job(
+                job_name=self.job_name,
+                project_id=self.project_id,
+                location=self.location,
+                job_id=self.job_id,
+            )
+        else:
+            self.log.info("No jobs to stop")
+
+        return None
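
A minimal usage sketch (illustrative only, not part of this commit) based on the docstring above: stopping a single job by its ID instead of by name prefix. The DAG id and job ID below are placeholders, and ``job_id`` and ``job_name_prefix`` must not be passed together.

    import datetime

    from airflow import models
    from airflow.providers.google.cloud.operators.dataflow import DataflowStopJobOperator

    with models.DAG(
        "example_stop_dataflow_job_by_id",  # illustrative DAG id
        start_date=datetime.datetime(2022, 10, 31),
        catchup=False,
    ) as dag:
        # job_id and job_name_prefix are mutually exclusive; here a single job is
        # stopped by its ID (placeholder value). Streaming jobs are drained by default.
        stop_dataflow_job_by_id = DataflowStopJobOperator(
            task_id="stop-dataflow-job-by-id",
            location="europe-west3",
            job_id="<DATAFLOW_JOB_ID>",  # placeholder, replace with a real Dataflow job ID
        )
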
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
index 32ac462f0c..76829275a2 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
@@ -238,6 +238,24 @@ Here is an example of running Dataflow SQL job with
 See the `Dataflow SQL reference
 <https://cloud.google.com/dataflow/docs/reference/sql>`_.
 
+.. _howto/operator:DataflowStopJobOperator:
+
+Stopping a pipeline
+^^^^^^^^^^^^^^^^^^^
+To stop one or more Dataflow pipelines you can use
+:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStopJobOperator`.
+Streaming pipelines are drained by default; setting ``drain_pipeline`` to ``False`` will cancel them instead.
+Provide ``job_id`` to stop a specific job, or ``job_name_prefix`` to stop all jobs with the provided name prefix.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataflow.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_stop_dataflow_job]
+    :end-before: [END howto_operator_stop_dataflow_job]
+
+See: `Stopping a running pipeline
+<https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline>`_.
+
 .. _howto/operator:DataflowJobStatusSensor:
 .. _howto/operator:DataflowJobMetricsSensor:
 .. _howto/operator:DataflowJobMessagesSensor:
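
As a hedged variant of the included example (not part of this commit), a streaming pipeline can be cancelled rather than drained by setting ``drain_pipeline`` to ``False``; the ``stop_timeout`` value below is illustrative.

    stop_dataflow_job_cancel = DataflowStopJobOperator(
        task_id="stop-dataflow-job-cancel",
        location="europe-west3",
        job_name_prefix="start-template-job",
        drain_pipeline=False,  # cancel streaming jobs instead of draining them
        stop_timeout=5 * 60,  # illustrative: wait up to 5 minutes for the cancel to complete
    )
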
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index d2db3b8e1b..ae7478880b 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1505,8 +1505,6 @@ class TestDataflowJob(unittest.TestCase):
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)
 
-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = mock_jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": "JOB_STATE_CANCELLED"},
@@ -1514,7 +1512,7 @@ class TestDataflowJob(unittest.TestCase):
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)
 
     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.timeout")
     @mock.patch("time.sleep")
@@ -1546,8 +1544,6 @@ class TestDataflowJob(unittest.TestCase):
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)
 
-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = mock_jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": "JOB_STATE_CANCELLED"},
@@ -1555,7 +1551,8 @@ class TestDataflowJob(unittest.TestCase):
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)
+
         mock_sleep.assert_has_calls([mock.call(4), mock.call(4), mock.call(4)])
         mock_timeout.assert_called_once_with(
             seconds=10, error_message="Canceling jobs failed due to timeout (10s): test-job-id"
@@ -1603,9 +1600,6 @@ class TestDataflowJob(unittest.TestCase):
 
         get_method.return_value.execute.assert_called_once_with(num_retries=20)
 
-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": requested_state},
@@ -1613,8 +1607,7 @@ class TestDataflowJob(unittest.TestCase):
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
-        mock_batch.execute.assert_called_once()
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)
 
     def test_dataflow_job_cancel_job_no_running_jobs(self):
         mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
@@ -1643,7 +1636,6 @@ class TestDataflowJob(unittest.TestCase):
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)
 
-        self.mock_dataflow.new_batch_http_request.assert_not_called()
         mock_jobs.return_value.update.assert_not_called()
 
     def test_fetch_list_job_messages_responses(self):
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 1a0f05d404..40e37c9487 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -29,6 +29,7 @@ from airflow.providers.google.cloud.operators.dataflow import (
     DataflowCreatePythonJobOperator,
     DataflowStartFlexTemplateOperator,
     DataflowStartSqlJobOperator,
+    DataflowStopJobOperator,
     DataflowTemplatedJobStartOperator,
 )
 from airflow.version import version
@@ -561,3 +562,56 @@ class TestDataflowSqlOperator(unittest.TestCase):
         mock_hook.return_value.cancel_job.assert_called_once_with(
             job_id="test-job-id", project_id=None, location=None
         )
+
+
+class TestDataflowStopJobOperator(unittest.TestCase):
+    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_exec_job_id(self, dataflow_mock):
+        """
+        Test DataflowHook is created and the right args are passed to cancel_job.
+        """
+        self.dataflow = DataflowStopJobOperator(
+            task_id=TASK_ID,
+            project_id=TEST_PROJECT,
+            job_id=JOB_ID,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        cancel_job_hook = dataflow_mock.return_value.cancel_job
+        self.dataflow.execute(None)
+        assert dataflow_mock.called
+        cancel_job_hook.assert_called_once_with(
+            job_name=None,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+            job_id=JOB_ID,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_exec_job_name_prefix(self, dataflow_mock):
+        """
+        Test DataflowHook is created and the right args are passed to cancel_job
+        and is_job_dataflow_running.
+        """
+        self.dataflow = DataflowStopJobOperator(
+            task_id=TASK_ID,
+            project_id=TEST_PROJECT,
+            job_name_prefix=JOB_NAME,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        is_job_running_hook = dataflow_mock.return_value.is_job_dataflow_running
+        cancel_job_hook = dataflow_mock.return_value.cancel_job
+        self.dataflow.execute(None)
+        assert dataflow_mock.called
+        is_job_running_hook.assert_called_once_with(
+            name=JOB_NAME,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+        )
+        cancel_job_hook.assert_called_once_with(
+            job_name=JOB_NAME,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+            job_id=None,
+        )