You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ta...@apache.org on 2023/03/02 09:49:10 UTC
[airflow] branch main updated: Improvements for RedshiftDataOperator: better error reporting and the ability to return SQL results (#29434)
This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ea8ce218b9 Improvements for RedshiftDataOperator: better error reporting and the ability to return SQL results (#29434)
ea8ce218b9 is described below
commit ea8ce218b9abe3c69f4c2d8c65180cf8bafebdd6
Author: Yuriy Badalyants <lm...@gmail.com>
AuthorDate: Thu Mar 2 22:48:47 2023 +1300
Improvements for RedshiftDataOperator: better error reporting and the ability to return SQL results (#29434)
---
.../providers/amazon/aws/hooks/redshift_data.py | 5 +--
.../amazon/aws/operators/redshift_data.py | 16 ++++++++--
.../amazon/aws/operators/test_redshift_data.py | 36 ++++++++++++++++++++++
3 files changed, 53 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index 6a8363522a..e033624c4c 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from pprint import pformat
from time import sleep
from typing import TYPE_CHECKING, Any, Iterable
@@ -112,8 +113,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
return status
elif status == "FAILED" or status == "ABORTED":
raise ValueError(
- f"Statement {statement_id!r} terminated with status {status}, "
- f"error msg: {resp.get('Error')}"
+ f"Statement {statement_id!r} terminated with status {status}. "
+ f"Response details: {pformat(resp)}"
)
else:
self.log.info("Query %s", status)
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index 4ba329c64f..6d85262671 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -25,6 +25,8 @@ from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
if TYPE_CHECKING:
+ from mypy_boto3_redshift_data.type_defs import GetStatementResultResponseTypeDef
+
from airflow.utils.context import Context
@@ -46,6 +48,8 @@ class RedshiftDataOperator(BaseOperator):
:param with_event: indicates whether to send an event to EventBridge
:param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
:param poll_interval: how often in seconds to check the query status
+ :param return_sql_result: if True will return the result of an SQL statement,
+ if False (default) will return statement ID
:param aws_conn_id: aws connection to use
:param region: aws region to use
"""
@@ -62,6 +66,7 @@ class RedshiftDataOperator(BaseOperator):
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
+ statement_id: str | None
def __init__(
self,
@@ -75,6 +80,7 @@ class RedshiftDataOperator(BaseOperator):
with_event: bool = False,
wait_for_completion: bool = True,
poll_interval: int = 10,
+ return_sql_result: bool = False,
aws_conn_id: str = "aws_default",
region: str | None = None,
await_result: bool | None = None,
@@ -106,6 +112,7 @@ class RedshiftDataOperator(BaseOperator):
"Invalid poll_interval:",
poll_interval,
)
+ self.return_sql_result = return_sql_result
self.aws_conn_id = aws_conn_id
self.region = region
self.statement_id: str | None = None
@@ -166,7 +173,7 @@ class RedshiftDataOperator(BaseOperator):
)
return self.hook.wait_for_results(statement_id=statement_id, poll_interval=self.poll_interval)
- def execute(self, context: Context) -> str:
+ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
"""Execute a statement against Amazon Redshift"""
self.log.info("Executing statement: %s", self.sql)
@@ -183,7 +190,12 @@ class RedshiftDataOperator(BaseOperator):
poll_interval=self.poll_interval,
)
- return self.statement_id
+ if self.return_sql_result:
+ result = self.hook.conn.get_statement_result(Id=self.statement_id)
+ self.log.debug("Statement result: %s", result)
+ return result
+ else:
+ return self.statement_id
def on_kill(self) -> None:
"""Cancel the submitted redshift query"""
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index 6915c13c11..8cb2b0877f 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -111,3 +111,39 @@ class TestRedshiftDataOperator:
await_result=True,
)
assert op.wait_for_completion
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_return_sql_result(self, mock_conn):
+ expected_result = {"Result": True}
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+ mock_conn.get_statement_result.return_value = expected_result
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ operator = RedshiftDataOperator(
+ task_id=TASK_ID,
+ cluster_identifier=cluster_identifier,
+ database=DATABASE,
+ db_user=db_user,
+ sql=SQL,
+ statement_name=statement_name,
+ secret_arn=secret_arn,
+ aws_conn_id=CONN_ID,
+ return_sql_result=True,
+ )
+ actual_result = operator.execute(None)
+ assert actual_result == expected_result
+ mock_conn.execute_statement.assert_called_once_with(
+ Database=DATABASE,
+ Sql=SQL,
+ ClusterIdentifier=cluster_identifier,
+ DbUser=db_user,
+ SecretArn=secret_arn,
+ StatementName=statement_name,
+ WithEvent=False,
+ )
+ mock_conn.get_statement_result.assert_called_once_with(
+ Id=STATEMENT_ID,
+ )