You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/10/16 10:44:53 UTC
[airflow] branch main updated: Cancel workflow in on_kill in DataprocInstantiate{Inline}WorkflowTemplateOperator (#34957)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0b49f338b9 Cancel workflow in on_kill in DataprocInstantiate{Inline}WorkflowTemplateOperator (#34957)
0b49f338b9 is described below
commit 0b49f338b9e6fd3264bc0099e8879855bf6c60c9
Author: Michał Sośnicki <so...@gmail.com>
AuthorDate: Mon Oct 16 12:44:43 2023 +0200
Cancel workflow in on_kill in DataprocInstantiate{Inline}WorkflowTemplateOperator (#34957)
* Cancel operation in on_kill in DataprocInstantiateWorkflowTemplateOperator
* Test on_kill method in DataprocInstantiateWorkflowTemplateOperator
---
.../providers/google/cloud/operators/dataproc.py | 42 ++++++++++----
.../google/cloud/operators/test_dataproc.py | 64 +++++++++++++++++++++-
2 files changed, 95 insertions(+), 11 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 3d41fdd1be..48e66a831a 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1790,6 +1790,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
+ :param cancel_on_kill: Flag which indicates whether to cancel the workflow when on_kill is called
"""
template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")
@@ -1812,6 +1813,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
polling_interval_seconds: int = 10,
+ cancel_on_kill: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -1830,6 +1832,8 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds
+ self.cancel_on_kill = cancel_on_kill
+ self.operation_name: str | None = None
def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -1845,24 +1849,26 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- self.workflow_id = operation.operation.name.split("/")[-1]
+ operation_name = operation.operation.name
+ self.operation_name = operation_name
+ workflow_id = operation_name.split("/")[-1]
project_id = self.project_id or hook.project_id
if project_id:
DataprocWorkflowLink.persist(
context=context,
operator=self,
- workflow_id=self.workflow_id,
+ workflow_id=workflow_id,
region=self.region,
project_id=project_id,
)
- self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
+ self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
if not self.deferrable:
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
- self.log.info("Workflow %s completed successfully", self.workflow_id)
+ self.log.info("Workflow %s completed successfully", workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
- name=operation.operation.name,
+ name=operation_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
@@ -1884,6 +1890,11 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
self.log.info("Workflow %s completed successfully", event["operation_name"])
+ def on_kill(self) -> None:
+ if self.cancel_on_kill and self.operation_name:
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+ hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)
+
class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator):
"""Instantiate a WorkflowTemplate Inline on Google Cloud Dataproc.
@@ -1926,6 +1937,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
+ :param cancel_on_kill: Flag which indicates whether to cancel the workflow when on_kill is called
"""
template_fields: Sequence[str] = ("template", "impersonation_chain")
@@ -1946,6 +1958,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
polling_interval_seconds: int = 10,
+ cancel_on_kill: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -1963,6 +1976,8 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds
+ self.cancel_on_kill = cancel_on_kill
+ self.operation_name: str | None = None
def execute(self, context: Context):
self.log.info("Instantiating Inline Template")
@@ -1977,23 +1992,25 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
timeout=self.timeout,
metadata=self.metadata,
)
- self.workflow_id = operation.operation.name.split("/")[-1]
+ operation_name = operation.operation.name
+ self.operation_name = operation_name
+ workflow_id = operation_name.split("/")[-1]
if project_id:
DataprocWorkflowLink.persist(
context=context,
operator=self,
- workflow_id=self.workflow_id,
+ workflow_id=workflow_id,
region=self.region,
project_id=project_id,
)
if not self.deferrable:
- self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
+ self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
operation.result()
- self.log.info("Workflow %s completed successfully", self.workflow_id)
+ self.log.info("Workflow %s completed successfully", workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
- name=operation.operation.name,
+ name=operation_name,
project_id=self.project_id or hook.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
@@ -2015,6 +2032,11 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
self.log.info("Workflow %s completed successfully", event["operation_name"])
+ def on_kill(self) -> None:
+ if self.cancel_on_kill and self.operation_name:
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+ hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)
+
class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
"""Submit a job to a cluster.
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 40180d4b47..02620ccb6c 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -1399,7 +1399,7 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
-class TestDataprocWorkflowTemplateInstantiateOperator:
+class TestDataprocInstantiateWorkflowTemplateOperator:
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
version = 6
@@ -1463,6 +1463,37 @@ class TestDataprocWorkflowTemplateInstantiateOperator:
assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_on_kill(self, mock_hook):
+ operation_name = "operation_name"
+ mock_hook.return_value.instantiate_workflow_template.return_value.operation.name = operation_name
+ op = DataprocInstantiateWorkflowTemplateOperator(
+ task_id=TASK_ID,
+ template_id=TEMPLATE_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ version=2,
+ parameters={},
+ request_id=REQUEST_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ cancel_on_kill=False,
+ )
+
+ op.execute(context=mock.MagicMock())
+
+ op.on_kill()
+ mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_not_called()
+
+ op.cancel_on_kill = True
+ op.on_kill()
+ mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_called_once_with(
+ name=operation_name
+ )
+
@pytest.mark.need_serialized_dag
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1561,6 +1592,37 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_on_kill(self, mock_hook):
+ operation_name = "operation_name"
+ mock_hook.return_value.instantiate_inline_workflow_template.return_value.operation.name = (
+ operation_name
+ )
+ op = DataprocInstantiateInlineWorkflowTemplateOperator(
+ task_id=TASK_ID,
+ template={},
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ request_id=REQUEST_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ cancel_on_kill=False,
+ )
+
+ op.execute(context=mock.MagicMock())
+
+ op.on_kill()
+ mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_not_called()
+
+ op.cancel_on_kill = True
+ op.on_kill()
+ mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_called_once_with(
+ name=operation_name
+ )
+
@pytest.mark.need_serialized_dag
@mock.patch(DATAPROC_PATH.format("DataprocHook"))