You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by on...@apache.org on 2023/03/01 02:27:36 UTC

[airflow] branch main updated: add num rows affected to Redshift Data API hook (#29797)

This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f7bc1ab3c add num rows affected to Redshift Data API hook (#29797)
1f7bc1ab3c is described below

commit 1f7bc1ab3c5bc5d51dda40197b52a111cb1f22ee
Author: Josh Dimarsky <24...@users.noreply.github.com>
AuthorDate: Tue Feb 28 21:27:19 2023 -0500

    add num rows affected to Redshift Data API hook (#29797)
---
 .../providers/amazon/aws/hooks/redshift_data.py    |  3 ++
 .../amazon/aws/hooks/test_redshift_data.py         | 42 ++++++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index e73c5a943a..6a8363522a 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -106,6 +106,9 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             )
             status = resp["Status"]
             if status == "FINISHED":
+                num_rows = resp.get("ResultRows")
+                if num_rows is not None:
+                    self.log.info("Processed %s rows", num_rows)
                 return status
             elif status == "FAILED" or status == "ABORTED":
                 raise ValueError(
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index 29816442c4..bbc5295a2d 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 from unittest import mock
 
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
@@ -186,3 +187,44 @@ class TestRedshiftDataHook:
             (dict(Id=STATEMENT_ID, NextToken="token1"),),
             (dict(Id=STATEMENT_ID, NextToken="token2"),),
         ]
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_result_num_rows(self, mock_conn, caplog):
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED", "ResultRows": 123}
+
+        hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+        # https://docs.pytest.org/en/stable/how-to/logging.html
+        with caplog.at_level(logging.INFO):
+            hook.execute_query(
+                sql=SQL,
+                database=DATABASE,
+                cluster_identifier=cluster_identifier,
+                db_user=db_user,
+                secret_arn=secret_arn,
+                statement_name=statement_name,
+                parameters=parameters,
+                wait_for_completion=True,
+            )
+            assert "Processed 123 rows" in caplog.text
+
+        # ensure message is not there when `ResultRows` is not returned
+        caplog.clear()
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+        with caplog.at_level(logging.INFO):
+            hook.execute_query(
+                sql=SQL,
+                database=DATABASE,
+                cluster_identifier=cluster_identifier,
+                db_user=db_user,
+                secret_arn=secret_arn,
+                statement_name=statement_name,
+                parameters=parameters,
+                wait_for_completion=True,
+            )
+            assert "Processed " not in caplog.text