Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/10 03:47:46 UTC

[airflow] branch main updated: Fix delay in Dataproc CreateBatch operator (#26126)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f0b600293 Fix delay in Dataproc CreateBatch operator (#26126)
6f0b600293 is described below

commit 6f0b600293ad53c1c4e3036b0572ca29b98b2fb2
Author: VladaZakharova <80...@users.noreply.github.com>
AuthorDate: Mon Oct 10 05:47:37 2022 +0200

    Fix delay in Dataproc CreateBatch operator (#26126)
---
 airflow/providers/google/cloud/hooks/dataproc.py   |  9 ++++--
 .../providers/google/cloud/operators/dataproc.py   |  8 +++++-
 .../google/cloud/operators/test_dataproc.py        | 31 +++++++++++++++++++++
 .../cloud/dataproc/example_dataproc_batch.py       | 32 ++++++++++++++++++++--
 4 files changed, 75 insertions(+), 5 deletions(-)
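
The delay being fixed comes from how the operator waits for the batch: operation.result() is polled with the default retry policy, whose exponential backoff can leave an already-finished batch unnoticed for a long stretch, so the next task in a chained DAG starts late. The commit threads a new result_retry parameter down to that call. A minimal sketch (not part of the commit, illustrative values only) of the two polling policies, using google.api_core.retry.Retry:

    from google.api_core.retry import Retry

    # Backoff polling: the wait between checks grows exponentially
    # (1s, 2s, 4s, ...) up to the 60-second cap.
    backoff_polling = Retry(initial=1.0, maximum=60.0, multiplier=2.0)

    # Fixed-interval polling: initial == maximum with multiplier == 1.0
    # checks the operation result every 10 seconds.
    fixed_polling = Retry(initial=10.0, maximum=10.0, multiplier=1.0)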

diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index 4ea0be8f06..c1ed1ac174 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -253,10 +253,15 @@ class DataprocHook(GoogleBaseHook):
             credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
-    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
+    def wait_for_operation(
+        self,
+        operation: Operation,
+        timeout: float | None = None,
+        result_retry: Retry | _MethodDefault = DEFAULT,
+    ):
         """Waits for long-lasting operation to complete."""
         try:
-            return operation.result(timeout=timeout)
+            return operation.result(timeout=timeout, retry=result_retry)
         except Exception:
             error = operation.exception(timeout=timeout)
             raise AirflowException(error)
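
With the new parameter, wait_for_operation() simply forwards result_retry to operation.result(). A hedged usage sketch of the hook (project, region, and batch values are hypothetical; a real run needs GCP credentials behind the google_cloud_default connection):

    from google.api_core.retry import Retry
    from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

    hook = DataprocHook(gcp_conn_id="google_cloud_default")
    operation = hook.create_batch(
        region="europe-west1",
        project_id="my-project",  # hypothetical project
        batch={
            "spark_batch": {
                "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
                "main_class": "org.apache.spark.examples.SparkPi",
            }
        },
        batch_id="my-batch-id",  # hypothetical batch id
    )
    # Poll the long-running operation every 10 seconds instead of backing off.
    result = hook.wait_for_operation(
        operation=operation,
        timeout=120.0,
        result_retry=Retry(initial=10.0, maximum=10.0, multiplier=1.0),
    )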
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 4304aca3eb..73666b95ea 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -2038,6 +2038,8 @@ class DataprocCreateBatchOperator(BaseOperator):
         the first ``google.longrunning.Operation`` created and stored in the backend is returned.
     :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
         retried.
+    :param result_retry: Retry object used when polling for the operation result. Can be used to decrease
+        the delay between chained tasks in a DAG by setting an exact polling interval in seconds.
     :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
         ``retry`` is specified, the timeout applies to each individual attempt.
     :param metadata: Additional metadata that is provided to the method.
@@ -2074,6 +2076,7 @@ class DataprocCreateBatchOperator(BaseOperator):
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        result_retry: Retry | _MethodDefault = DEFAULT,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -2083,6 +2086,7 @@ class DataprocCreateBatchOperator(BaseOperator):
         self.batch_id = batch_id
         self.request_id = request_id
         self.retry = retry
+        self.result_retry = result_retry
         self.timeout = timeout
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
@@ -2107,7 +2111,9 @@ class DataprocCreateBatchOperator(BaseOperator):
             )
             if self.operation is None:
                 raise RuntimeError("The operation should be set here!")
-            result = hook.wait_for_operation(timeout=self.timeout, operation=self.operation)
+            result = hook.wait_for_operation(
+                timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
+            )
             self.log.info("Batch %s created", self.batch_id)
         except AlreadyExists:
             self.log.info("Batch with given id already exists")
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 629ff85edc..f9449e2dd7 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -193,6 +193,7 @@ UPDATE_MASK = {
 
 TIMEOUT = 120
 RETRY = mock.MagicMock(Retry)
+RESULT_RETRY = mock.MagicMock(Retry)
 METADATA = [("key", "value")]
 REQUEST_ID = "request_id_uuid"
 
@@ -1706,6 +1707,36 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_with_result_retry(self, mock_hook, to_dict_mock):
+        op = DataprocCreateBatchOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            batch=BATCH,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            result_retry=RESULT_RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        op.execute(context=MagicMock())
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.create_batch.assert_called_once_with(
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            batch=BATCH,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_batch_failed(self, mock_hook, to_dict_mock):
diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
index 0a7b42aaf5..25e276c1ba 100644
--- a/tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
+++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
@@ -22,6 +22,8 @@ from __future__ import annotations
 import os
 from datetime import datetime
 
+from google.api_core.retry import Retry
+
 from airflow import models
 from airflow.providers.google.cloud.operators.dataproc import (
     DataprocCreateBatchOperator,
@@ -36,6 +38,7 @@ DAG_ID = "dataproc_batch"
 PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "")
 REGION = "europe-west1"
 BATCH_ID = f"test-batch-id-{ENV_ID}"
+BATCH_ID_2 = f"test-batch-id-{ENV_ID}-2"
 BATCH_CONFIG = {
     "spark_batch": {
         "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
@@ -58,7 +61,15 @@ with models.DAG(
         region=REGION,
         batch=BATCH_CONFIG,
         batch_id=BATCH_ID,
-        timeout=5.0,
+    )
+
+    create_batch_2 = DataprocCreateBatchOperator(
+        task_id="create_batch_2",
+        project_id=PROJECT_ID,
+        region=REGION,
+        batch=BATCH_CONFIG,
+        batch_id=BATCH_ID_2,
+        result_retry=Retry(maximum=10.0, initial=10.0, multiplier=1.0),
     )
     # [END how_to_cloud_dataproc_create_batch_operator]
 
@@ -66,6 +77,10 @@ with models.DAG(
     get_batch = DataprocGetBatchOperator(
         task_id="get_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
     )
+
+    get_batch_2 = DataprocGetBatchOperator(
+        task_id="get_batch_2", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID_2
+    )
     # [END how_to_cloud_dataproc_get_batch_operator]
 
     # [START how_to_cloud_dataproc_list_batches_operator]
@@ -80,10 +95,23 @@ with models.DAG(
     delete_batch = DataprocDeleteBatchOperator(
         task_id="delete_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
     )
+    delete_batch.trigger_rule = TriggerRule.ALL_DONE
+
+    delete_batch_2 = DataprocDeleteBatchOperator(
+        task_id="delete_batch_2", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID_2
+    )
     # [END how_to_cloud_dataproc_delete_batch_operator]
     delete_batch.trigger_rule = TriggerRule.ALL_DONE
 
-    create_batch >> get_batch >> list_batches >> delete_batch
+    (
+        create_batch
+        >> create_batch_2
+        >> get_batch
+        >> get_batch_2
+        >> list_batches
+        >> delete_batch
+        >> delete_batch_2
+    )
 
     from tests.system.utils.watcher import watcher
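
A note on the retry values in the example above: with initial=10.0, maximum=10.0 and multiplier=1.0, google.api_core.retry.Retry never grows the wait between attempts, so the batch result is polled at a steady 10-second interval. The old example instead relied on a short timeout=5.0 on the create step, which this commit removes.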