Posted to commits@airflow.apache.org by ta...@apache.org on 2023/03/02 09:49:10 UTC

[airflow] branch main updated: Improvements for RedshiftDataOperator: better error reporting and the ability to return SQL results (#29434)

This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ea8ce218b9 Improvements for RedshiftDataOperator: better error reporting and the ability to return SQL results (#29434)
ea8ce218b9 is described below

commit ea8ce218b9abe3c69f4c2d8c65180cf8bafebdd6
Author: Yuriy Badalyants <lm...@gmail.com>
AuthorDate: Thu Mar 2 22:48:47 2023 +1300

    Improvements for RedshiftDataOperator: better error reporting and the ability to return SQL results (#29434)
---
 .../providers/amazon/aws/hooks/redshift_data.py    |  5 +--
 .../amazon/aws/operators/redshift_data.py          | 16 ++++++++--
 .../amazon/aws/operators/test_redshift_data.py     | 36 ++++++++++++++++++++++
 3 files changed, 53 insertions(+), 4 deletions(-)
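
For illustration, a minimal sketch of how the new return_sql_result flag could be used in a DAG. The cluster, database, user, and task names below are hypothetical, not part of this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    with DAG(dag_id="redshift_data_example", start_date=datetime(2023, 3, 1), schedule=None):
        # With return_sql_result=True the operator returns (and pushes to XCom)
        # the full GetStatementResult response instead of the statement ID.
        fetch_rows = RedshiftDataOperator(
            task_id="fetch_rows",
            cluster_identifier="my-cluster",  # hypothetical
            database="dev",                   # hypothetical
            db_user="awsuser",                # hypothetical
            sql="SELECT 1 AS answer;",
            return_sql_result=True,
        )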

diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index 6a8363522a..e033624c4c 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from pprint import pformat
 from time import sleep
 from typing import TYPE_CHECKING, Any, Iterable
 
@@ -112,8 +113,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
                 return status
             elif status == "FAILED" or status == "ABORTED":
                 raise ValueError(
-                    f"Statement {statement_id!r} terminated with status {status}, "
-                    f"error msg: {resp.get('Error')}"
+                    f"Statement {statement_id!r} terminated with status {status}. "
+                    f"Response details: {pformat(resp)}"
                 )
             else:
                 self.log.info("Query %s", status)
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index 4ba329c64f..6d85262671 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -25,6 +25,8 @@ from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 
 if TYPE_CHECKING:
+    from mypy_boto3_redshift_data.type_defs import GetStatementResultResponseTypeDef
+
     from airflow.utils.context import Context
 
 
@@ -46,6 +48,8 @@ class RedshiftDataOperator(BaseOperator):
     :param with_event: indicates whether to send an event to EventBridge
     :param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
     :param poll_interval: how often in seconds to check the query status
+    :param return_sql_result: if True will return the result of an SQL statement,
+        if False (default) will return statement ID
     :param aws_conn_id: aws connection to use
     :param region: aws region to use
     """
@@ -62,6 +66,7 @@ class RedshiftDataOperator(BaseOperator):
     )
     template_ext = (".sql",)
     template_fields_renderers = {"sql": "sql"}
+    statement_id: str | None
 
     def __init__(
         self,
@@ -75,6 +80,7 @@ class RedshiftDataOperator(BaseOperator):
         with_event: bool = False,
         wait_for_completion: bool = True,
         poll_interval: int = 10,
+        return_sql_result: bool = False,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
         await_result: bool | None = None,
@@ -106,6 +112,7 @@ class RedshiftDataOperator(BaseOperator):
                 "Invalid poll_interval:",
                 poll_interval,
             )
+        self.return_sql_result = return_sql_result
         self.aws_conn_id = aws_conn_id
         self.region = region
         self.statement_id: str | None = None
@@ -166,7 +173,7 @@ class RedshiftDataOperator(BaseOperator):
         )
         return self.hook.wait_for_results(statement_id=statement_id, poll_interval=self.poll_interval)
 
-    def execute(self, context: Context) -> str:
+    def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
         """Execute a statement against Amazon Redshift"""
         self.log.info("Executing statement: %s", self.sql)
 
@@ -183,7 +190,12 @@ class RedshiftDataOperator(BaseOperator):
             poll_interval=self.poll_interval,
         )
 
-        return self.statement_id
+        if self.return_sql_result:
+            result = self.hook.conn.get_statement_result(Id=self.statement_id)
+            self.log.debug("Statement result: %s", result)
+            return result
+        else:
+            return self.statement_id
 
     def on_kill(self) -> None:
         """Cancel the submitted redshift query"""
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index 6915c13c11..8cb2b0877f 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -111,3 +111,39 @@ class TestRedshiftDataOperator:
                 await_result=True,
             )
         assert op.wait_for_completion
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_return_sql_result(self, mock_conn):
+        expected_result = {"Result": True}
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+        mock_conn.get_statement_result.return_value = expected_result
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        operator = RedshiftDataOperator(
+            task_id=TASK_ID,
+            cluster_identifier=cluster_identifier,
+            database=DATABASE,
+            db_user=db_user,
+            sql=SQL,
+            statement_name=statement_name,
+            secret_arn=secret_arn,
+            aws_conn_id=CONN_ID,
+            return_sql_result=True,
+        )
+        actual_result = operator.execute(None)
+        assert actual_result == expected_result
+        mock_conn.execute_statement.assert_called_once_with(
+            Database=DATABASE,
+            Sql=SQL,
+            ClusterIdentifier=cluster_identifier,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
+            WithEvent=False,
+        )
+        mock_conn.get_statement_result.assert_called_once_with(
+            Id=STATEMENT_ID,
+        )