You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/10/16 10:44:53 UTC

[airflow] branch main updated: Cancel workflow in on_kill in DataprocInstantiate{Inline}WorkflowTemplateOperator (#34957)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b49f338b9 Cancel workflow in on_kill in DataprocInstantiate{Inline}WorkflowTemplateOperator (#34957)
0b49f338b9 is described below

commit 0b49f338b9e6fd3264bc0099e8879855bf6c60c9
Author: Michał Sośnicki <so...@gmail.com>
AuthorDate: Mon Oct 16 12:44:43 2023 +0200

    Cancel workflow in on_kill in DataprocInstantiate{Inline}WorkflowTemplateOperator (#34957)
    
    * Cancel operation in on_kill in DataprocInstantiateWorkflowTemplateOperator
    
    * Test on_kill method in DataprocInstantiateWorkflowTemplateOperator
---
 .../providers/google/cloud/operators/dataproc.py   | 42 ++++++++++----
 .../google/cloud/operators/test_dataproc.py        | 64 +++++++++++++++++++++-
 2 files changed, 95 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 3d41fdd1be..48e66a831a 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1790,6 +1790,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     :param deferrable: Run operator in the deferrable mode.
     :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
+    :param cancel_on_kill: Flag which indicates whether to cancel the workflow when on_kill is called
     """
 
     template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")
@@ -1812,6 +1813,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         polling_interval_seconds: int = 10,
+        cancel_on_kill: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -1830,6 +1832,8 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
         self.polling_interval_seconds = polling_interval_seconds
+        self.cancel_on_kill = cancel_on_kill
+        self.operation_name: str | None = None
 
     def execute(self, context: Context):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -1845,24 +1849,26 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        self.workflow_id = operation.operation.name.split("/")[-1]
+        operation_name = operation.operation.name
+        self.operation_name = operation_name
+        workflow_id = operation_name.split("/")[-1]
         project_id = self.project_id or hook.project_id
         if project_id:
             DataprocWorkflowLink.persist(
                 context=context,
                 operator=self,
-                workflow_id=self.workflow_id,
+                workflow_id=workflow_id,
                 region=self.region,
                 project_id=project_id,
             )
-        self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
+        self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
         if not self.deferrable:
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
-            self.log.info("Workflow %s completed successfully", self.workflow_id)
+            self.log.info("Workflow %s completed successfully", workflow_id)
         else:
             self.defer(
                 trigger=DataprocWorkflowTrigger(
-                    name=operation.operation.name,
+                    name=operation_name,
                     project_id=self.project_id,
                     region=self.region,
                     gcp_conn_id=self.gcp_conn_id,
@@ -1884,6 +1890,11 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
 
         self.log.info("Workflow %s completed successfully", event["operation_name"])
 
+    def on_kill(self) -> None:
+        if self.cancel_on_kill and self.operation_name:
+            hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+            hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)
+
 
 class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator):
     """Instantiate a WorkflowTemplate Inline on Google Cloud Dataproc.
@@ -1926,6 +1937,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
         account from the list granting this role to the originating account (templated).
     :param deferrable: Run operator in the deferrable mode.
     :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
+    :param cancel_on_kill: Flag which indicates whether to cancel the workflow when on_kill is called
     """
 
     template_fields: Sequence[str] = ("template", "impersonation_chain")
@@ -1946,6 +1958,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         polling_interval_seconds: int = 10,
+        cancel_on_kill: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -1963,6 +1976,8 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
         self.polling_interval_seconds = polling_interval_seconds
+        self.cancel_on_kill = cancel_on_kill
+        self.operation_name: str | None = None
 
     def execute(self, context: Context):
         self.log.info("Instantiating Inline Template")
@@ -1977,23 +1992,25 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        self.workflow_id = operation.operation.name.split("/")[-1]
+        operation_name = operation.operation.name
+        self.operation_name = operation_name
+        workflow_id = operation_name.split("/")[-1]
         if project_id:
             DataprocWorkflowLink.persist(
                 context=context,
                 operator=self,
-                workflow_id=self.workflow_id,
+                workflow_id=workflow_id,
                 region=self.region,
                 project_id=project_id,
             )
         if not self.deferrable:
-            self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
+            self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
             operation.result()
-            self.log.info("Workflow %s completed successfully", self.workflow_id)
+            self.log.info("Workflow %s completed successfully", workflow_id)
         else:
             self.defer(
                 trigger=DataprocWorkflowTrigger(
-                    name=operation.operation.name,
+                    name=operation_name,
                     project_id=self.project_id or hook.project_id,
                     region=self.region,
                     gcp_conn_id=self.gcp_conn_id,
@@ -2015,6 +2032,11 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
 
         self.log.info("Workflow %s completed successfully", event["operation_name"])
 
+    def on_kill(self) -> None:
+        if self.cancel_on_kill and self.operation_name:
+            hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+            hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)
+
 
 class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
     """Submit a job to a cluster.
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 40180d4b47..02620ccb6c 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -1399,7 +1399,7 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
     assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
 
 
-class TestDataprocWorkflowTemplateInstantiateOperator:
+class TestDataprocInstantiateWorkflowTemplateOperator:
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook):
         version = 6
@@ -1463,6 +1463,37 @@ class TestDataprocWorkflowTemplateInstantiateOperator:
         assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_on_kill(self, mock_hook):
+        operation_name = "operation_name"
+        mock_hook.return_value.instantiate_workflow_template.return_value.operation.name = operation_name
+        op = DataprocInstantiateWorkflowTemplateOperator(
+            task_id=TASK_ID,
+            template_id=TEMPLATE_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            version=2,
+            parameters={},
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            cancel_on_kill=False,
+        )
+
+        op.execute(context=mock.MagicMock())
+
+        op.on_kill()
+        mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_not_called()
+
+        op.cancel_on_kill = True
+        op.on_kill()
+        mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_called_once_with(
+            name=operation_name
+        )
+
 
 @pytest.mark.need_serialized_dag
 @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1561,6 +1592,37 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
         assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_on_kill(self, mock_hook):
+        operation_name = "operation_name"
+        mock_hook.return_value.instantiate_inline_workflow_template.return_value.operation.name = (
+            operation_name
+        )
+        op = DataprocInstantiateInlineWorkflowTemplateOperator(
+            task_id=TASK_ID,
+            template={},
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            cancel_on_kill=False,
+        )
+
+        op.execute(context=mock.MagicMock())
+
+        op.on_kill()
+        mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_not_called()
+
+        op.cancel_on_kill = True
+        op.on_kill()
+        mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_called_once_with(
+            name=operation_name
+        )
+
 
 @pytest.mark.need_serialized_dag
 @mock.patch(DATAPROC_PATH.format("DataprocHook"))