Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/31 05:15:23 UTC
[airflow] branch main updated: DataflowStopJobOperator Operator (#27033)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 50d217a129 DataflowStopJobOperator Operator (#27033)
50d217a129 is described below
commit 50d217a1290f891be5d6be743b00b552fc10da20
Author: Rafał Grudowski <40...@users.noreply.github.com>
AuthorDate: Mon Oct 31 06:15:15 2022 +0100
DataflowStopJobOperator Operator (#27033)
---
.../google/cloud/example_dags/example_dataflow.py | 25 ++++++
airflow/providers/google/cloud/hooks/dataflow.py | 25 +++---
.../providers/google/cloud/operators/dataflow.py | 93 ++++++++++++++++++++++
.../operators/cloud/dataflow.rst | 18 +++++
.../providers/google/cloud/hooks/test_dataflow.py | 16 +---
.../google/cloud/operators/test_dataflow.py | 54 +++++++++++++
6 files changed, 209 insertions(+), 22 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py
index 1364de183e..b7b0c017dc 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataflow.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py
@@ -34,6 +34,7 @@ from airflow.providers.apache.beam.operators.beam import (
 from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning,
+    DataflowStopJobOperator,
     DataflowTemplatedJobStartOperator,
 )
 from airflow.providers.google.cloud.sensors.dataflow import (
@@ -261,3 +262,27 @@ with models.DAG(
location="europe-west3",
)
# [END howto_operator_start_template_job]
+
+with models.DAG(
+ "example_gcp_stop_dataflow_job",
+ default_args=default_args,
+ start_date=START_DATE,
+ catchup=False,
+ tags=["example"],
+) as dag_template:
+ # [START howto_operator_stop_dataflow_job]
+ stop_dataflow_job = DataflowStopJobOperator(
+ task_id="stop-dataflow-job",
+ location="europe-west3",
+ job_name_prefix="start-template-job",
+ )
+ # [END howto_operator_stop_dataflow_job]
+ start_template_job = DataflowTemplatedJobStartOperator(
+ task_id="start-template-job",
+ template="gs://dataflow-templates/latest/Word_Count",
+ parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
+ location="europe-west3",
+ append_job_name=False,
+ )
+
+ stop_dataflow_job >> start_template_job
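A note on the dependency order in the example above: stop_dataflow_job runs first, so each DAG run stops any "start-template-job" Dataflow job left over from a previous run before submitting a fresh one. The pairing relies on ``append_job_name=False``, which keeps the submitted job's name equal to the task-provided name that the stop task then matches as its ``job_name_prefix``.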
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 490024b693..b9dbf9478c 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -237,6 +237,8 @@ class _DataflowJobsController(LoggingMixin):
"""
if not self._multiple_jobs and self._job_id:
return [self.fetch_job_by_id(self._job_id)]
+ elif self._jobs:
+ return [self.fetch_job_by_id(job["id"]) for job in self._jobs]
elif self._job_name:
jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower())
if len(jobs) == 1:
@@ -445,11 +447,11 @@ class _DataflowJobsController(LoggingMixin):
             job_states = {job["currentState"] for job in self._jobs}
             if not job_states.difference(expected_states):
                 return
-            unexpected_failed_end_states = expected_states - DataflowJobStatus.FAILED_END_STATES
+            unexpected_failed_end_states = DataflowJobStatus.FAILED_END_STATES - expected_states
             if unexpected_failed_end_states.intersection(job_states):
-                unexpected_failed_jobs = {
+                unexpected_failed_jobs = [
                     job for job in self._jobs if job["currentState"] in unexpected_failed_end_states
-                }
+                ]
                 raise AirflowException(
                     "Jobs failed: "
                     + ", ".join(
@@ -461,18 +463,19 @@ class _DataflowJobsController(LoggingMixin):

     def cancel(self) -> None:
         """Cancels or drains current job"""
-        jobs = self.get_jobs()
-        job_ids = [job["id"] for job in jobs if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES]
+        self._jobs = [
+            job for job in self.get_jobs() if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES
+        ]
+        job_ids = [job["id"] for job in self._jobs]
         if job_ids:
-            batch = self._dataflow.new_batch_http_request()
             self.log.info("Canceling jobs: %s", ", ".join(job_ids))
-            for job in jobs:
+            for job in self._jobs:
                 requested_state = (
                     DataflowJobStatus.JOB_STATE_DRAINED
                     if self.drain_pipeline and job["type"] == DataflowJobType.JOB_TYPE_STREAMING
                     else DataflowJobStatus.JOB_STATE_CANCELLED
                 )
-                batch.add(
+                request = (
                     self._dataflow.projects()
                     .locations()
                     .jobs()
@@ -483,14 +486,16 @@ class _DataflowJobsController(LoggingMixin):
body={"requestedState": requested_state},
)
)
- batch.execute()
+ request.execute(num_retries=self._num_retries)
if self._cancel_timeout and isinstance(self._cancel_timeout, int):
timeout_error_message = (
f"Canceling jobs failed due to timeout ({self._cancel_timeout}s): {', '.join(job_ids)}"
)
tm = timeout(seconds=self._cancel_timeout, error_message=timeout_error_message)
with tm:
- self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
+ self._wait_for_states(
+ {DataflowJobStatus.JOB_STATE_CANCELLED, DataflowJobStatus.JOB_STATE_DRAINED}
+ )
else:
self.log.info("No jobs to cancel")
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 6d0a0412ae..7cfab6bbff 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -1137,3 +1137,96 @@ class DataflowCreatePythonJobOperator(BaseOperator):
             self.dataflow_hook.cancel_job(
                 job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id
             )
+
+
+class DataflowStopJobOperator(BaseOperator):
+    """
+    Stops the job with the specified name prefix or Job ID.
+    All jobs with the provided name prefix will be stopped.
+    Streaming jobs are drained by default.
+
+    The ``job_name_prefix`` and ``job_id`` parameters are mutually exclusive.
+
+    .. seealso::
+        For more details on stopping a pipeline see:
+        https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DataflowStopJobOperator`
+
+    :param job_name_prefix: Name prefix specifying which jobs are to be stopped.
+    :param job_id: Job ID specifying which jobs are to be stopped.
+    :param project_id: Optional, the Google Cloud project ID in which to start a job.
+        If set to None or missing, the default project_id from the Google Cloud connection is used.
+    :param location: Optional, Job location. If set to None or missing, "us-central1" will be used.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :param poll_sleep: The time in seconds to sleep between polling Google
+        Cloud Platform for the Dataflow job status to confirm it's stopped.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with the first
+        account from the list granting this role to the originating account (templated).
+    :param drain_pipeline: Optional, set to False to stop a streaming job by canceling it
+        instead of draining it. See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+    :param stop_timeout: Wait time in seconds for successful job canceling/draining.
+    """
+
+    def __init__(
+        self,
+        job_name_prefix: str | None = None,
+        job_id: str | None = None,
+        project_id: str | None = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: str | None = None,
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+        stop_timeout: int | None = 10 * 60,
+        drain_pipeline: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poll_sleep = poll_sleep
+        self.stop_timeout = stop_timeout
+        self.job_name = job_name_prefix
+        self.job_id = job_id
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        self.hook: DataflowHook | None = None
+        self.drain_pipeline = drain_pipeline
+
+    def execute(self, context: Context) -> None:
+        self.dataflow_hook = DataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+            cancel_timeout=self.stop_timeout,
+            drain_pipeline=self.drain_pipeline,
+        )
+        if self.job_id or self.dataflow_hook.is_job_dataflow_running(
+            name=self.job_name,
+            project_id=self.project_id,
+            location=self.location,
+        ):
+            self.dataflow_hook.cancel_job(
+                job_name=self.job_name,
+                project_id=self.project_id,
+                location=self.location,
+                job_id=self.job_id,
+            )
+        else:
+            self.log.info("No jobs to stop")
+
+        return None
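A hedged usage sketch of the new operator, complementing the example DAG earlier in this commit (the DAG ID and the Dataflow job ID below are placeholders, not values from the commit): stop one specific job by ID, cancel rather than drain, and bound the wait.

    from datetime import datetime

    from airflow import models
    from airflow.providers.google.cloud.operators.dataflow import DataflowStopJobOperator

    with models.DAG(
        "example_stop_single_dataflow_job",
        start_date=datetime(2022, 10, 31),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        stop_single_job = DataflowStopJobOperator(
            task_id="stop-single-job",
            job_id="2022-10-31_00_00_00-1234567890",  # placeholder Dataflow job ID
            location="europe-west3",
            drain_pipeline=False,  # cancel even streaming jobs instead of draining them
            stop_timeout=600,  # poll for up to 10 minutes for the job to reach a stopped state
        )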
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
index 32ac462f0c..76829275a2 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
@@ -238,6 +238,24 @@ Here is an example of running Dataflow SQL job with
See the `Dataflow SQL reference
<https://cloud.google.com/dataflow/docs/reference/sql>`_.

+.. _howto/operator:DataflowStopJobOperator:
+
+Stopping a pipeline
+^^^^^^^^^^^^^^^^^^^
+To stop one or more Dataflow pipelines you can use
+:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStopJobOperator`.
+Streaming pipelines are drained by default; set ``drain_pipeline`` to ``False`` to cancel them instead.
+Provide ``job_id`` to stop a specific job, or ``job_name_prefix`` to stop all jobs with the provided name prefix.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataflow.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_stop_dataflow_job]
+    :end-before: [END howto_operator_stop_dataflow_job]
+
+See: `Stopping a running pipeline
+<https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline>`_.
+
.. _howto/operator:DataflowJobStatusSensor:
.. _howto/operator:DataflowJobMetricsSensor:
.. _howto/operator:DataflowJobMessagesSensor:
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index d2db3b8e1b..ae7478880b 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1505,8 +1505,6 @@ class TestDataflowJob(unittest.TestCase):
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)

-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = mock_jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": "JOB_STATE_CANCELLED"},
@@ -1514,7 +1512,7 @@
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)

     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.timeout")
     @mock.patch("time.sleep")
@@ -1546,8 +1544,6 @@
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)

-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = mock_jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": "JOB_STATE_CANCELLED"},
@@ -1555,7 +1551,8 @@
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)
+
         mock_sleep.assert_has_calls([mock.call(4), mock.call(4), mock.call(4)])
         mock_timeout.assert_called_once_with(
             seconds=10, error_message="Canceling jobs failed due to timeout (10s): test-job-id"
@@ -1603,9 +1600,6 @@
         get_method.return_value.execute.assert_called_once_with(num_retries=20)

-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": requested_state},
@@ -1613,8 +1607,7 @@
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
-        mock_batch.execute.assert_called_once()
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)

     def test_dataflow_job_cancel_job_no_running_jobs(self):
         mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
@@ -1643,7 +1636,6 @@
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)

-        self.mock_dataflow.new_batch_http_request.assert_not_called()
         mock_jobs.return_value.update.assert_not_called()

     def test_fetch_list_job_messages_responses(self):
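For anyone puzzling over the updated assertions: with MagicMock, calling an attribute and reading its return_value resolve to the same child mock, so the production call chain in cancel() is observable through mock_update.return_value. A minimal illustration, not from the commit (the argument values mirror the test constants):

    from unittest import mock

    mock_dataflow = mock.MagicMock()
    mock_update = mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.update

    # The production-style call chain from cancel():
    request = (
        mock_dataflow.projects()
        .locations()
        .jobs()
        .update(
            projectId="test-project",
            location="us-central1",
            jobId="test-job-id",
            body={"requestedState": "JOB_STATE_CANCELLED"},
        )
    )
    request.execute(num_retries=20)

    # projects() and projects.return_value are the same child mock, so the
    # assertion below sees the execute() call made through the chain above.
    mock_update.return_value.execute.assert_called_once_with(num_retries=20)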
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 1a0f05d404..40e37c9487 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -29,6 +29,7 @@ from airflow.providers.google.cloud.operators.dataflow import (
     DataflowCreatePythonJobOperator,
     DataflowStartFlexTemplateOperator,
     DataflowStartSqlJobOperator,
+    DataflowStopJobOperator,
     DataflowTemplatedJobStartOperator,
 )
 from airflow.version import version
@@ -561,3 +562,56 @@ class TestDataflowSqlOperator(unittest.TestCase):
         mock_hook.return_value.cancel_job.assert_called_once_with(
             job_id="test-job-id", project_id=None, location=None
         )
+
+
+class TestDataflowStopJobOperator(unittest.TestCase):
+    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_exec_job_id(self, dataflow_mock):
+        """
+        Test that a DataflowHook is created and the right args are passed to cancel_job.
+        """
+        self.dataflow = DataflowStopJobOperator(
+            task_id=TASK_ID,
+            project_id=TEST_PROJECT,
+            job_id=JOB_ID,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        cancel_job_hook = dataflow_mock.return_value.cancel_job
+        self.dataflow.execute(None)
+        assert dataflow_mock.called
+        cancel_job_hook.assert_called_once_with(
+            job_name=None,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+            job_id=JOB_ID,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_exec_job_name_prefix(self, dataflow_mock):
+        """
+        Test that a DataflowHook is created and the right args are passed to cancel_job
+        and is_job_dataflow_running.
+        """
+        self.dataflow = DataflowStopJobOperator(
+            task_id=TASK_ID,
+            project_id=TEST_PROJECT,
+            job_name_prefix=JOB_NAME,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        is_job_running_hook = dataflow_mock.return_value.is_job_dataflow_running
+        cancel_job_hook = dataflow_mock.return_value.cancel_job
+        self.dataflow.execute(None)
+        assert dataflow_mock.called
+        is_job_running_hook.assert_called_once_with(
+            name=JOB_NAME,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+        )
+        cancel_job_hook.assert_called_once_with(
+            job_name=JOB_NAME,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+            job_id=None,
+        )