You are viewing a plain text version of this content. The link to the canonical version was a hyperlink in the original archive page and is not preserved in this copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/12/05 02:41:27 UTC

[airflow] branch main updated: add some important log in aws athena hook (#27917)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8cf6dca36b add some important log in aws athena hook (#27917)
8cf6dca36b is described below

commit 8cf6dca36b0cfc16763cb1d4c96ab04d1fe5ec14
Author: Bob Du <i...@bobdu.cc>
AuthorDate: Mon Dec 5 10:41:18 2022 +0800

    add some important log in aws athena hook (#27917)
---
 airflow/providers/amazon/aws/hooks/athena.py     | 65 ++++++++++++++++++------
 airflow/providers/amazon/aws/operators/athena.py | 13 +++--
 tests/providers/amazon/aws/hooks/test_athena.py  | 21 ++++++++
 3 files changed, 80 insertions(+), 19 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index 3c70898a34..0f80357fea 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -44,6 +44,8 @@ class AthenaHook(AwsBaseHook):
         :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
 
     :param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena
+    :param log_query: Whether to log athena query and other execution params when it's executed.
+        Defaults to *True*.
     """
 
     INTERMEDIATE_STATES = (
@@ -61,9 +63,10 @@ class AthenaHook(AwsBaseHook):
         "CANCELLED",
     )
 
-    def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, sleep_time: int = 30, log_query: bool = True, **kwargs: Any) -> None:
         super().__init__(client_type="athena", *args, **kwargs)  # type: ignore
         self.sleep_time = sleep_time
+        self.log_query = log_query
 
     def run_query(
         self,
@@ -91,8 +94,12 @@ class AthenaHook(AwsBaseHook):
         }
         if client_request_token:
             params["ClientRequestToken"] = client_request_token
+        if self.log_query:
+            self.log.info("Running Query with params: %s", params)
         response = self.get_conn().start_query_execution(**params)
-        return response["QueryExecutionId"]
+        query_execution_id = response["QueryExecutionId"]
+        self.log.info("Query execution id: %s", query_execution_id)
+        return query_execution_id
 
     def check_query_status(self, query_execution_id: str) -> str | None:
         """
@@ -105,8 +112,10 @@ class AthenaHook(AwsBaseHook):
         state = None
         try:
             state = response["QueryExecution"]["Status"]["State"]
-        except Exception as ex:
-            self.log.error("Exception while getting query state %s", ex)
+        except Exception:
+            self.log.exception(
+                "Exception while getting query state. Query execution id: %s", query_execution_id
+            )
         finally:
             # The error is being absorbed here and is being handled by the caller.
             # The error is being absorbed to implement retries.
@@ -123,8 +132,11 @@ class AthenaHook(AwsBaseHook):
         reason = None
         try:
             reason = response["QueryExecution"]["Status"]["StateChangeReason"]
-        except Exception as ex:
-            self.log.error("Exception while getting query state change reason: %s", ex)
+        except Exception:
+            self.log.exception(
+                "Exception while getting query state change reason. Query execution id: %s",
+                query_execution_id,
+            )
         finally:
             # The error is being absorbed here and is being handled by the caller.
             # The error is being absorbed to implement retries.
@@ -144,10 +156,14 @@ class AthenaHook(AwsBaseHook):
         """
         query_state = self.check_query_status(query_execution_id)
         if query_state is None:
-            self.log.error("Invalid Query state")
+            self.log.error("Invalid Query state. Query execution id: %s", query_execution_id)
             return None
         elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
-            self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
+            self.log.error(
+                'Query is in "%s" state. Cannot fetch results. Query execution id: %s',
+                query_state,
+                query_execution_id,
+            )
             return None
         result_params = {"QueryExecutionId": query_execution_id, "MaxResults": max_results}
         if next_token_id:
@@ -174,10 +190,14 @@ class AthenaHook(AwsBaseHook):
         """
         query_state = self.check_query_status(query_execution_id)
         if query_state is None:
-            self.log.error("Invalid Query state (null)")
+            self.log.error("Invalid Query state (null). Query execution id: %s", query_execution_id)
             return None
         if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
-            self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
+            self.log.error(
+                'Query is in "%s" state. Cannot fetch results. Query execution id: %s',
+                query_state,
+                query_execution_id,
+            )
             return None
         result_params = {
             "QueryExecutionId": query_execution_id,
@@ -222,15 +242,27 @@ class AthenaHook(AwsBaseHook):
         while True:
             query_state = self.check_query_status(query_execution_id)
             if query_state is None:
-                self.log.info("Trial %s: Invalid query state. Retrying again", try_number)
+                self.log.info(
+                    "Query execution id: %s, trial %s: Invalid query state. Retrying again",
+                    query_execution_id,
+                    try_number,
+                )
             elif query_state in self.TERMINAL_STATES:
                 self.log.info(
-                    "Trial %s: Query execution completed. Final state is %s}", try_number, query_state
+                    "Query execution id: %s, trial %s: Query execution completed. Final state is %s",
+                    query_execution_id,
+                    try_number,
+                    query_state,
                 )
                 final_query_state = query_state
                 break
             else:
-                self.log.info("Trial %s: Query is still in non-terminal state - %s", try_number, query_state)
+                self.log.info(
+                    "Query execution id: %s, trial %s: Query is still in non-terminal state - %s",
+                    query_execution_id,
+                    try_number,
+                    query_state,
+                )
             if (
                 max_polling_attempts and try_number >= max_polling_attempts
             ):  # Break loop if max_polling_attempts reached
@@ -256,12 +288,14 @@ class AthenaHook(AwsBaseHook):
                 try:
                     output_location = response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
                 except KeyError:
-                    self.log.error("Error retrieving OutputLocation")
+                    self.log.error(
+                        "Error retrieving OutputLocation. Query execution id: %s", query_execution_id
+                    )
                     raise
             else:
                 raise
         else:
-            raise ValueError("Invalid Query execution id")
+            raise ValueError(f"Invalid Query execution id. Query execution id: {query_execution_id}")
 
         return output_location
 
@@ -272,4 +306,5 @@ class AthenaHook(AwsBaseHook):
         :param query_execution_id: Id of submitted athena query
         :return: dict
         """
+        self.log.info("Stopping Query with executionId - %s", query_execution_id)
         return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 61c897a481..9a414f2c28 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -48,6 +48,8 @@ class AthenaOperator(BaseOperator):
     :param max_tries: Deprecated - use max_polling_attempts instead.
     :param max_polling_attempts: Number of times to poll for query state before function exits
         To limit task execution time, use execution_timeout.
+    :param log_query: Whether to log athena query and other execution params when it's executed.
+        Defaults to *True*.
     """
 
     ui_color = "#44b5e2"
@@ -69,6 +71,7 @@ class AthenaOperator(BaseOperator):
         sleep_time: int = 30,
         max_tries: int | None = None,
         max_polling_attempts: int | None = None,
+        log_query: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -83,6 +86,7 @@ class AthenaOperator(BaseOperator):
         self.sleep_time = sleep_time
         self.max_polling_attempts = max_polling_attempts
         self.query_execution_id: str | None = None
+        self.log_query: bool = log_query
 
         if max_tries:
             warnings.warn(
@@ -99,7 +103,7 @@ class AthenaOperator(BaseOperator):
     @cached_property
     def hook(self) -> AthenaHook:
         """Create and return an AthenaHook."""
-        return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
+        return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time, log_query=self.log_query)
 
     def execute(self, context: Context) -> str | None:
         """Run Presto Query on Athena"""
@@ -135,13 +139,14 @@ class AthenaOperator(BaseOperator):
         """Cancel the submitted athena query"""
         if self.query_execution_id:
             self.log.info("Received a kill signal.")
-            self.log.info("Stopping Query with executionId - %s", self.query_execution_id)
             response = self.hook.stop_query(self.query_execution_id)
             http_status_code = None
             try:
                 http_status_code = response["ResponseMetadata"]["HTTPStatusCode"]
-            except Exception as ex:
-                self.log.error("Exception while cancelling query: %s", ex)
+            except Exception:
+                self.log.exception(
+                    "Exception while cancelling query. Query execution id: %s", self.query_execution_id
+                )
             finally:
                 if http_status_code is None or http_status_code != 200:
                     self.log.error("Unable to request query cancel on athena. Exiting")
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py
index 65d549f925..58870bb1d0 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -92,6 +92,27 @@ class TestAthenaHook(unittest.TestCase):
         mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params)
         assert result == MOCK_DATA["query_execution_id"]
 
+    @mock.patch.object(AthenaHook, "log")
+    @mock.patch.object(AthenaHook, "get_conn")
+    def test_hook_run_query_log_query(self, mock_conn, log):
+        self.athena.run_query(
+            query=MOCK_DATA["query"],
+            query_context=mock_query_context,
+            result_configuration=mock_result_configuration,
+        )
+        assert self.athena.log.info.call_count == 2
+
+    @mock.patch.object(AthenaHook, "log")
+    @mock.patch.object(AthenaHook, "get_conn")
+    def test_hook_run_query_no_log_query(self, mock_conn, log):
+        athena_hook_no_log_query = AthenaHook(sleep_time=0, log_query=False)
+        athena_hook_no_log_query.run_query(
+            query=MOCK_DATA["query"],
+            query_context=mock_query_context,
+            result_configuration=mock_result_configuration,
+        )
+        assert athena_hook_no_log_query.log.info.call_count == 1
+
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn):
         mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION