You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/09/11 23:25:20 UTC
[airflow] branch main updated: Athena and EMR operator max_retries mix-up fix (#25971)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d5820a77e8 Athena and EMR operator max_retries mix-up fix (#25971)
d5820a77e8 is described below
commit d5820a77e896a1a3ceb671eddddb9c8e3bcfb649
Author: syedahsn <10...@users.noreply.github.com>
AuthorDate: Sun Sep 11 17:25:02 2022 -0600
Athena and EMR operator max_retries mix-up fix (#25971)
* Internally rename `max_tries` to `max_polling_attempts`. Add a deprecation warning to inform users about the naming change.
Raise an exception if the value of `max_tries` does not match the value of `max_polling_attempts`.
---
airflow/providers/amazon/aws/hooks/athena.py | 29 ++++++++++++++++++++----
airflow/providers/amazon/aws/hooks/emr.py | 28 +++++++++++++++++++----
airflow/providers/amazon/aws/operators/athena.py | 22 +++++++++++++++---
airflow/providers/amazon/aws/operators/emr.py | 24 ++++++++++++++++----
4 files changed, 88 insertions(+), 15 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index 0c224e5fd5..f777a5b558 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -23,6 +23,7 @@ This module contains AWS Athena hook.
PageIterator
"""
+import warnings
from time import sleep
from typing import Any, Dict, Optional
@@ -188,17 +189,35 @@ class AthenaHook(AwsBaseHook):
paginator = self.get_conn().get_paginator('get_query_results')
return paginator.paginate(**result_params)
- def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] = None) -> Optional[str]:
+ def poll_query_status(
+ self,
+ query_execution_id: str,
+ max_tries: Optional[int] = None,
+ max_polling_attempts: Optional[int] = None,
+ ) -> Optional[str]:
"""
Poll the status of submitted athena query until query state reaches final state.
Returns one of the final states
:param query_execution_id: Id of submitted athena query
- :param max_tries: Number of times to poll for query state before function exits
+ :param max_tries: Deprecated - Use max_polling_attempts instead
+ :param max_polling_attempts: Number of times to poll for query state before function exits
:return: str
"""
+ if max_tries:
+ warnings.warn(
+ f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+ "in a future release. Please use method `max_polling_attempts` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if max_polling_attempts and max_polling_attempts != max_tries:
+ raise Exception("max_polling_attempts must be the same value as max_tries")
+ else:
+ max_polling_attempts = max_tries
+
try_number = 1
- final_query_state = None # Query state when query reaches final state or max_tries reached
+ final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
while True:
query_state = self.check_query_status(query_execution_id)
if query_state is None:
@@ -211,7 +230,9 @@ class AthenaHook(AwsBaseHook):
break
else:
self.log.info('Trial %s: Query is still in non-terminal state - %s', try_number, query_state)
- if max_tries and try_number >= max_tries: # Break loop if max_tries reached
+ if (
+ max_polling_attempts and try_number >= max_polling_attempts
+ ): # Break loop if max_polling_attempts reached
final_query_state = query_state
break
try_number += 1
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index 48fc7684f1..8c86edf0f8 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import warnings
from time import sleep
from typing import Any, Callable, Dict, List, Optional, Set
@@ -321,19 +322,36 @@ class EmrContainerHook(AwsBaseHook):
return None
def poll_query_status(
- self, job_id: str, max_tries: Optional[int] = None, poll_interval: int = 30
+ self,
+ job_id: str,
+ max_tries: Optional[int] = None,
+ poll_interval: int = 30,
+ max_polling_attempts: Optional[int] = None,
) -> Optional[str]:
"""
Poll the status of submitted job run until query state reaches final state.
Returns one of the final states.
:param job_id: Id of submitted job run
- :param max_tries: Number of times to poll for query state before function exits
+ :param max_tries: Deprecated - Use max_polling_attempts instead
:param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
+ :param max_polling_attempts: Number of times to poll for query state before function exits
:return: str
"""
+ if max_tries:
+ warnings.warn(
+ f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+ "in a future release. Please use method `max_polling_attempts` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if max_polling_attempts and max_polling_attempts != max_tries:
+ raise Exception("max_polling_attempts must be the same value as max_tries")
+ else:
+ max_polling_attempts = max_tries
+
try_number = 1
- final_query_state = None # Query state when query reaches final state or max_tries reached
+ final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
while True:
query_state = self.check_query_status(job_id)
@@ -345,7 +363,9 @@ class EmrContainerHook(AwsBaseHook):
break
else:
self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
- if max_tries and try_number >= max_tries: # Break loop if max_tries reached
+ if (
+ max_polling_attempts and try_number >= max_polling_attempts
+ ): # Break loop if max_polling_attempts reached
final_query_state = query_state
break
try_number += 1
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 679e71000a..6fd01324fc 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
+import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
from airflow.compat.functools import cached_property
@@ -43,7 +44,9 @@ class AthenaOperator(BaseOperator):
:param query_execution_context: Context in which query need to be run
:param result_configuration: Dict with path to store results in and config related to encryption
:param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena
- :param max_tries: Number of times to poll for query state before function exits
+ :param max_tries: Deprecated - use max_polling_attempts instead.
+ :param max_polling_attempts: Number of times to poll for query state before function exits
+ To limit task execution time, use execution_timeout.
"""
ui_color = '#44b5e2'
@@ -64,6 +67,7 @@ class AthenaOperator(BaseOperator):
result_configuration: Optional[Dict[str, Any]] = None,
sleep_time: int = 30,
max_tries: Optional[int] = None,
+ max_polling_attempts: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -76,9 +80,21 @@ class AthenaOperator(BaseOperator):
self.query_execution_context = query_execution_context or {}
self.result_configuration = result_configuration or {}
self.sleep_time = sleep_time
- self.max_tries = max_tries
+ self.max_polling_attempts = max_polling_attempts
self.query_execution_id = None # type: Optional[str]
+ if max_tries:
+ warnings.warn(
+ f"Parameter `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+ "in a future release. Please use method `max_polling_attempts` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if max_polling_attempts and max_polling_attempts != max_tries:
+ raise Exception("max_polling_attempts must be the same value as max_tries")
+ else:
+ self.max_polling_attempts = max_tries
+
@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
@@ -95,7 +111,7 @@ class AthenaOperator(BaseOperator):
self.client_request_token,
self.workgroup,
)
- query_status = self.hook.poll_query_status(self.query_execution_id, self.max_tries)
+ query_status = self.hook.poll_query_status(self.query_execution_id, self.max_polling_attempts)
if query_status in AthenaHook.FAILURE_STATES:
error_message = self.hook.get_state_change_reason(self.query_execution_id)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 3b7bda0cd6..f06f3834e9 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import ast
+import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4
@@ -195,7 +196,8 @@ class EmrContainerOperator(BaseOperator):
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: Whether or not to wait in the operator for the job to complete.
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR
- :param max_tries: Maximum number of times to wait for the job run to finish.
+ :param max_tries: Deprecated - use max_polling_attempts instead.
+ :param max_polling_attempts: Maximum number of times to wait for the job run to finish.
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
:param tags: The tags assigned to job runs.
Defaults to None
@@ -225,6 +227,7 @@ class EmrContainerOperator(BaseOperator):
poll_interval: int = 30,
max_tries: Optional[int] = None,
tags: Optional[dict] = None,
+ max_polling_attempts: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -238,10 +241,22 @@ class EmrContainerOperator(BaseOperator):
self.client_request_token = client_request_token or str(uuid4())
self.wait_for_completion = wait_for_completion
self.poll_interval = poll_interval
- self.max_tries = max_tries
+ self.max_polling_attempts = max_polling_attempts
self.tags = tags
self.job_id: Optional[str] = None
+ if max_tries:
+ warnings.warn(
+ f"Parameter `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+ "in a future release. Please use method `max_polling_attempts` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if max_polling_attempts and max_polling_attempts != max_tries:
+ raise Exception("max_polling_attempts must be the same value as max_tries")
+ else:
+ self.max_polling_attempts = max_tries
+
@cached_property
def hook(self) -> EmrContainerHook:
"""Create and return an EmrContainerHook."""
@@ -262,7 +277,9 @@ class EmrContainerOperator(BaseOperator):
self.tags,
)
if self.wait_for_completion:
- query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)
+ query_status = self.hook.poll_query_status(
+ self.job_id, self.max_polling_attempts, self.poll_interval
+ )
if query_status in EmrContainerHook.FAILURE_STATES:
error_message = self.hook.get_job_failure_reason(self.job_id)
@@ -352,7 +369,6 @@ class EmrCreateJobFlowOperator(BaseOperator):
self.log.info(
'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s', self.aws_conn_id, self.emr_conn_id
)
-
if isinstance(self.job_flow_overrides, str):
job_flow_overrides: Dict[str, Any] = ast.literal_eval(self.job_flow_overrides)
self.job_flow_overrides = job_flow_overrides