You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/04/14 14:38:04 UTC
[airflow] branch main updated: Fix cancel_on_kill after execution timeout for DataprocSubmitJobOperator (#22955)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ea1ae1963e Fix cancel_on_kill after execution timeout for DataprocSubmitJobOperator (#22955)
ea1ae1963e is described below
commit ea1ae1963ecf1b543e4f5e8deb59d623df42d44a
Author: Krzysztof Mioduszewski <km...@gmail.com>
AuthorDate: Thu Apr 14 16:37:54 2022 +0200
Fix cancel_on_kill after execution timeout for DataprocSubmitJobOperator (#22955)
Synchronous tasks killed by execution timeout weren't canceled because
self.job_id was assigned only after wait_for_job returned, so it was never set when the wait was interrupted.
---
.../providers/google/cloud/operators/dataproc.py | 16 ++++++------
.../google/cloud/operators/test_dataproc.py | 29 ++++++++++++++++++++++
2 files changed, 38 insertions(+), 7 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 2b69dfdeb4..423f7e8fe7 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1858,19 +1858,21 @@ class DataprocSubmitJobOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- job_id = job_object.reference.job_id
- self.log.info('Job %s submitted successfully.', job_id)
+ new_job_id: str = job_object.reference.job_id
+ self.log.info('Job %s submitted successfully.', new_job_id)
# Save data required by extra links no matter what the job status will be
- DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id)
+ DataprocLink.persist(
+ context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=new_job_id
+ )
+ self.job_id = new_job_id
if not self.asynchronous:
- self.log.info('Waiting for job %s to complete', job_id)
+ self.log.info('Waiting for job %s to complete', new_job_id)
self.hook.wait_for_job(
- job_id=job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout
+ job_id=new_job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout
)
- self.log.info('Job %s completed successfully.', job_id)
+ self.log.info('Job %s completed successfully.', new_job_id)
- self.job_id = job_id
return self.job_id
def on_kill(self):
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 527765135c..ec2392f168 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -25,6 +25,7 @@ from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry
from airflow import AirflowException
+from airflow.exceptions import AirflowTaskTimeout
from airflow.models import DAG, DagBag
from airflow.providers.google.cloud.operators.dataproc import (
DATAPROC_CLUSTER_LINK,
@@ -877,6 +878,34 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
project_id=GCP_PROJECT, region=GCP_LOCATION, job_id=job_id
)
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_on_kill_after_execution_timeout(self, mock_hook):
+ job = {}
+ job_id = "job_id"
+ mock_hook.return_value.wait_for_job.side_effect = AirflowTaskTimeout()
+ mock_hook.return_value.submit_job.return_value.reference.job_id = job_id
+
+ op = DataprocSubmitJobOperator(
+ task_id=TASK_ID,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ job=job,
+ gcp_conn_id=GCP_CONN_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ request_id=REQUEST_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ cancel_on_kill=True,
+ )
+ with pytest.raises(AirflowTaskTimeout):
+ op.execute(context=self.mock_context)
+
+ op.on_kill()
+ mock_hook.return_value.cancel_job.assert_called_once_with(
+ project_id=GCP_PROJECT, region=GCP_LOCATION, job_id=job_id
+ )
+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_location_deprecation_warning(self, mock_hook):
xcom_push_call = call.ti.xcom_push(execution_date=None, key='conf', value=DATAPROC_JOB_CONF_EXPECTED)