You are viewing a plain text version of this content. The canonical link to it (a hyperlink on the original archive page) is not preserved in this text.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/09/11 23:25:20 UTC

[airflow] branch main updated: Athena and EMR operator max_retries mix-up fix (#25971)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d5820a77e8 Athena and EMR operator max_retries mix-up fix (#25971)
d5820a77e8 is described below

commit d5820a77e896a1a3ceb671eddddb9c8e3bcfb649
Author: syedahsn <10...@users.noreply.github.com>
AuthorDate: Sun Sep 11 17:25:02 2022 -0600

    Athena and EMR operator max_retries mix-up fix (#25971)
    
    * Rename the `max_tries` parameter to `max_polling_attempts`, keeping `max_tries` as a deprecated alias, and add a deprecation warning to inform users about the naming change.
    Raise an exception if both `max_tries` and `max_polling_attempts` are given with different values.
---
 airflow/providers/amazon/aws/hooks/athena.py     | 29 ++++++++++++++++++++----
 airflow/providers/amazon/aws/hooks/emr.py        | 28 +++++++++++++++++++----
 airflow/providers/amazon/aws/operators/athena.py | 22 +++++++++++++++---
 airflow/providers/amazon/aws/operators/emr.py    | 24 ++++++++++++++++----
 4 files changed, 88 insertions(+), 15 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index 0c224e5fd5..f777a5b558 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -23,6 +23,7 @@ This module contains AWS Athena hook.
 
     PageIterator
 """
+import warnings
 from time import sleep
 from typing import Any, Dict, Optional
 
@@ -188,17 +189,35 @@ class AthenaHook(AwsBaseHook):
         paginator = self.get_conn().get_paginator('get_query_results')
         return paginator.paginate(**result_params)
 
-    def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] = None) -> Optional[str]:
+    def poll_query_status(
+        self,
+        query_execution_id: str,
+        max_tries: Optional[int] = None,
+        max_polling_attempts: Optional[int] = None,
+    ) -> Optional[str]:
         """
         Poll the status of submitted athena query until query state reaches final state.
         Returns one of the final states
 
         :param query_execution_id: Id of submitted athena query
-        :param max_tries: Number of times to poll for query state before function exits
+        :param max_tries: Deprecated - Use max_polling_attempts instead
+        :param max_polling_attempts: Number of times to poll for query state before function exits
         :return: str
         """
+        if max_tries:
+            warnings.warn(
+                f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+                "in a future release.  Please use method `max_polling_attempts` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if max_polling_attempts and max_polling_attempts != max_tries:
+                raise Exception("max_polling_attempts must be the same value as max_tries")
+            else:
+                max_polling_attempts = max_tries
+
         try_number = 1
-        final_query_state = None  # Query state when query reaches final state or max_tries reached
+        final_query_state = None  # Query state when query reaches final state or max_polling_attempts reached
         while True:
             query_state = self.check_query_status(query_execution_id)
             if query_state is None:
@@ -211,7 +230,9 @@ class AthenaHook(AwsBaseHook):
                 break
             else:
                 self.log.info('Trial %s: Query is still in non-terminal state - %s', try_number, query_state)
-            if max_tries and try_number >= max_tries:  # Break loop if max_tries reached
+            if (
+                max_polling_attempts and try_number >= max_polling_attempts
+            ):  # Break loop if max_polling_attempts reached
                 final_query_state = query_state
                 break
             try_number += 1
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index 48fc7684f1..8c86edf0f8 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -15,6 +15,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import warnings
 from time import sleep
 from typing import Any, Callable, Dict, List, Optional, Set
 
@@ -321,19 +322,36 @@ class EmrContainerHook(AwsBaseHook):
             return None
 
     def poll_query_status(
-        self, job_id: str, max_tries: Optional[int] = None, poll_interval: int = 30
+        self,
+        job_id: str,
+        max_tries: Optional[int] = None,
+        poll_interval: int = 30,
+        max_polling_attempts: Optional[int] = None,
     ) -> Optional[str]:
         """
         Poll the status of submitted job run until query state reaches final state.
         Returns one of the final states.
 
         :param job_id: Id of submitted job run
-        :param max_tries: Number of times to poll for query state before function exits
+        :param max_tries: Deprecated - Use max_polling_attempts instead
         :param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
+        :param max_polling_attempts: Number of times to poll for query state before function exits
         :return: str
         """
+        if max_tries:
+            warnings.warn(
+                f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+                "in a future release.  Please use method `max_polling_attempts` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if max_polling_attempts and max_polling_attempts != max_tries:
+                raise Exception("max_polling_attempts must be the same value as max_tries")
+            else:
+                max_polling_attempts = max_tries
+
         try_number = 1
-        final_query_state = None  # Query state when query reaches final state or max_tries reached
+        final_query_state = None  # Query state when query reaches final state or max_polling_attempts reached
 
         while True:
             query_state = self.check_query_status(job_id)
@@ -345,7 +363,9 @@ class EmrContainerHook(AwsBaseHook):
                 break
             else:
                 self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
-            if max_tries and try_number >= max_tries:  # Break loop if max_tries reached
+            if (
+                max_polling_attempts and try_number >= max_polling_attempts
+            ):  # Break loop if max_polling_attempts reached
                 final_query_state = query_state
                 break
             try_number += 1
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 679e71000a..6fd01324fc 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+import warnings
 from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
 
 from airflow.compat.functools import cached_property
@@ -43,7 +44,9 @@ class AthenaOperator(BaseOperator):
     :param query_execution_context: Context in which query need to be run
     :param result_configuration: Dict with path to store results in and config related to encryption
     :param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena
-    :param max_tries: Number of times to poll for query state before function exits
+    :param max_tries: Deprecated - use max_polling_attempts instead.
+    :param max_polling_attempts: Number of times to poll for query state before function exits
+        To limit task execution time, use execution_timeout.
     """
 
     ui_color = '#44b5e2'
@@ -64,6 +67,7 @@ class AthenaOperator(BaseOperator):
         result_configuration: Optional[Dict[str, Any]] = None,
         sleep_time: int = 30,
         max_tries: Optional[int] = None,
+        max_polling_attempts: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -76,9 +80,21 @@ class AthenaOperator(BaseOperator):
         self.query_execution_context = query_execution_context or {}
         self.result_configuration = result_configuration or {}
         self.sleep_time = sleep_time
-        self.max_tries = max_tries
+        self.max_polling_attempts = max_polling_attempts
         self.query_execution_id = None  # type: Optional[str]
 
+        if max_tries:
+            warnings.warn(
+                f"Parameter `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+                "in a future release.  Please use method `max_polling_attempts` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if max_polling_attempts and max_polling_attempts != max_tries:
+                raise Exception("max_polling_attempts must be the same value as max_tries")
+            else:
+                self.max_polling_attempts = max_tries
+
     @cached_property
     def hook(self) -> AthenaHook:
         """Create and return an AthenaHook."""
@@ -95,7 +111,7 @@ class AthenaOperator(BaseOperator):
             self.client_request_token,
             self.workgroup,
         )
-        query_status = self.hook.poll_query_status(self.query_execution_id, self.max_tries)
+        query_status = self.hook.poll_query_status(self.query_execution_id, self.max_polling_attempts)
 
         if query_status in AthenaHook.FAILURE_STATES:
             error_message = self.hook.get_state_change_reason(self.query_execution_id)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 3b7bda0cd6..f06f3834e9 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import ast
+import warnings
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 from uuid import uuid4
 
@@ -195,7 +196,8 @@ class EmrContainerOperator(BaseOperator):
     :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion: Whether or not to wait in the operator for the job to complete.
     :param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR
-    :param max_tries: Maximum number of times to wait for the job run to finish.
+    :param max_tries: Deprecated - use max_polling_attempts instead.
+    :param max_polling_attempts: Maximum number of times to wait for the job run to finish.
         Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
     :param tags: The tags assigned to job runs.
         Defaults to None
@@ -225,6 +227,7 @@ class EmrContainerOperator(BaseOperator):
         poll_interval: int = 30,
         max_tries: Optional[int] = None,
         tags: Optional[dict] = None,
+        max_polling_attempts: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -238,10 +241,22 @@ class EmrContainerOperator(BaseOperator):
         self.client_request_token = client_request_token or str(uuid4())
         self.wait_for_completion = wait_for_completion
         self.poll_interval = poll_interval
-        self.max_tries = max_tries
+        self.max_polling_attempts = max_polling_attempts
         self.tags = tags
         self.job_id: Optional[str] = None
 
+        if max_tries:
+            warnings.warn(
+                f"Parameter `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+                "in a future release.  Please use method `max_polling_attempts` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if max_polling_attempts and max_polling_attempts != max_tries:
+                raise Exception("max_polling_attempts must be the same value as max_tries")
+            else:
+                self.max_polling_attempts = max_tries
+
     @cached_property
     def hook(self) -> EmrContainerHook:
         """Create and return an EmrContainerHook."""
@@ -262,7 +277,9 @@ class EmrContainerOperator(BaseOperator):
             self.tags,
         )
         if self.wait_for_completion:
-            query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)
+            query_status = self.hook.poll_query_status(
+                self.job_id, self.max_polling_attempts, self.poll_interval
+            )
 
             if query_status in EmrContainerHook.FAILURE_STATES:
                 error_message = self.hook.get_job_failure_reason(self.job_id)
@@ -352,7 +369,6 @@ class EmrCreateJobFlowOperator(BaseOperator):
         self.log.info(
             'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s', self.aws_conn_id, self.emr_conn_id
         )
-
         if isinstance(self.job_flow_overrides, str):
             job_flow_overrides: Dict[str, Any] = ast.literal_eval(self.job_flow_overrides)
             self.job_flow_overrides = job_flow_overrides