Posted to commits@airflow.apache.org by po...@apache.org on 2022/04/14 14:38:04 UTC

[airflow] branch main updated: Fix cancel_on_kill after execution timeout for DataprocSubmitJobOperator (#22955)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ea1ae1963e Fix cancel_on_kill after execution timeout for DataprocSubmitJobOperator (#22955)
ea1ae1963e is described below

commit ea1ae1963ecf1b543e4f5e8deb59d623df42d44a
Author: Krzysztof Mioduszewski <km...@gmail.com>
AuthorDate: Thu Apr 14 16:37:54 2022 +0200

    Fix cancel_on_kill after execution timeout for DataprocSubmitJobOperator (#22955)
    
    Synchronous tasks killed by an execution timeout weren't canceled:
    self.job_id was assigned only after wait_for_job() returned, so when
    the timeout interrupted the wait, on_kill() had no job ID to cancel.
---
 .../providers/google/cloud/operators/dataproc.py   | 16 ++++++------
 .../google/cloud/operators/test_dataproc.py        | 29 ++++++++++++++++++++++
 2 files changed, 38 insertions(+), 7 deletions(-)
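
A minimal, self-contained sketch of the failure mode fixed below (toy names:
_ToyHook and BuggyOperator are hypothetical stand-ins, and TimeoutError
stands in for Airflow's AirflowTaskTimeout):

    class _ToyHook:
        """Hypothetical stand-in for DataprocHook, for illustration only."""
        def submit_job(self):
            return "job-123"

        def wait_for_job(self, job_id):
            # Stands in for the execution timeout firing mid-wait.
            raise TimeoutError("execution timeout")

        def cancel_job(self, job_id):
            print(f"cancelled {job_id}")

    class BuggyOperator:
        """Pre-fix ordering: self.job_id is set only after the blocking wait."""
        def __init__(self):
            self.hook = _ToyHook()
            self.job_id = None
            self.cancel_on_kill = True

        def execute(self):
            job_id = self.hook.submit_job()
            self.hook.wait_for_job(job_id=job_id)  # raises before the next line
            self.job_id = job_id                   # never reached on timeout

        def on_kill(self):
            # With self.job_id still unset, nothing is cancelled.
            if self.job_id and self.cancel_on_kill:
                self.hook.cancel_job(job_id=self.job_id)

    op = BuggyOperator()
    try:
        op.execute()
    except TimeoutError:
        op.on_kill()  # silently does nothing: self.job_id was never assigned

Assigning self.job_id before the blocking wait, as the diff below does, lets
on_kill() cancel the already-submitted job.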

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 2b69dfdeb4..423f7e8fe7 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1858,19 +1858,21 @@ class DataprocSubmitJobOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        job_id = job_object.reference.job_id
-        self.log.info('Job %s submitted successfully.', job_id)
+        new_job_id: str = job_object.reference.job_id
+        self.log.info('Job %s submitted successfully.', new_job_id)
         # Save data required by extra links no matter what the job status will be
-        DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id)
+        DataprocLink.persist(
+            context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=new_job_id
+        )
 
+        self.job_id = new_job_id
         if not self.asynchronous:
-            self.log.info('Waiting for job %s to complete', job_id)
+            self.log.info('Waiting for job %s to complete', new_job_id)
             self.hook.wait_for_job(
-                job_id=job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout
+                job_id=new_job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout
             )
-            self.log.info('Job %s completed successfully.', job_id)
+            self.log.info('Job %s completed successfully.', new_job_id)
 
-        self.job_id = job_id
         return self.job_id
 
     def on_kill(self):
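
With self.job_id now set before the blocking wait, on_kill() can cancel the
job even when wait_for_job() is interrupted. The on_kill body itself is
unchanged by this commit and not shown above; a sketch of the shape it is
assumed to have, inferred from the cancel_job call asserted by the new test
below (not the verbatim operator code):

    from types import SimpleNamespace

    def on_kill(self):
        # Assumed shape: cancel only when a job ID was recorded and the
        # operator opted in via cancel_on_kill.
        if self.job_id and self.cancel_on_kill:
            self.hook.cancel_job(
                project_id=self.project_id, region=self.region, job_id=self.job_id
            )

    # Toy invocation with stand-in attribute values:
    op = SimpleNamespace(
        job_id="job-123",
        cancel_on_kill=True,
        project_id="my-project",
        region="us-central1",
        hook=SimpleNamespace(cancel_job=lambda **kw: print("cancel_job:", kw)),
    )
    on_kill(op)  # prints the project_id, region, and job_id it would cancel
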
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 527765135c..ec2392f168 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -25,6 +25,7 @@ from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry
 
 from airflow import AirflowException
+from airflow.exceptions import AirflowTaskTimeout
 from airflow.models import DAG, DagBag
 from airflow.providers.google.cloud.operators.dataproc import (
     DATAPROC_CLUSTER_LINK,
@@ -877,6 +878,34 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
             project_id=GCP_PROJECT, region=GCP_LOCATION, job_id=job_id
         )
 
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_on_kill_after_execution_timeout(self, mock_hook):
+        job = {}
+        job_id = "job_id"
+        mock_hook.return_value.wait_for_job.side_effect = AirflowTaskTimeout()
+        mock_hook.return_value.submit_job.return_value.reference.job_id = job_id
+
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            job=job,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            request_id=REQUEST_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            cancel_on_kill=True,
+        )
+        with pytest.raises(AirflowTaskTimeout):
+            op.execute(context=self.mock_context)
+
+        op.on_kill()
+        mock_hook.return_value.cancel_job.assert_called_once_with(
+            project_id=GCP_PROJECT, region=GCP_LOCATION, job_id=job_id
+        )
+
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_location_deprecation_warning(self, mock_hook):
         xcom_push_call = call.ti.xcom_push(execution_date=None, key='conf', value=DATAPROC_JOB_CONF_EXPECTED)
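
The new test's key mechanism is unittest.mock's side_effect: assigning an
exception class or instance makes the mocked wait_for_job raise when called,
simulating the execution timeout striking mid-execute. A minimal standalone
sketch of that pattern (generic names, not Airflow code; TimeoutError stands
in for AirflowTaskTimeout):

    from unittest import mock

    import pytest

    hook = mock.MagicMock()
    # An exception assigned to side_effect is raised when the mock is called.
    hook.wait_for_job.side_effect = TimeoutError("simulated execution timeout")
    hook.submit_job.return_value.reference.job_id = "job_id"

    def run(hook):
        job_id = hook.submit_job().reference.job_id
        hook.wait_for_job(job_id=job_id)  # raises TimeoutError here
        return job_id

    with pytest.raises(TimeoutError):
        run(hook)

    hook.wait_for_job.assert_called_once_with(job_id="job_id")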