Posted to commits@airflow.apache.org by po...@apache.org on 2022/03/11 11:49:07 UTC

[airflow] branch main updated: Fix RedshiftDataOperator and update doc (#22157)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 46a120d  Fix RedshiftDataOperator and update doc (#22157)
46a120d is described below

commit 46a120dc5f37d0a38cbfe338af215dc63e590aff
Author: vincbeck <97...@users.noreply.github.com>
AuthorDate: Fri Mar 11 06:47:59 2022 -0500

    Fix RedshiftDataOperator and update doc (#22157)
---
 .../example_redshift_data_execute_sql.py           | 58 ++++++++++------------
 .../amazon/aws/operators/redshift_data.py          | 25 ++++++----
 .../operators/redshift_data.rst                    | 39 +++++++--------
 .../amazon/aws/operators/test_redshift_data.py     | 26 ++++++----
 4 files changed, 73 insertions(+), 75 deletions(-)

diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py b/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
index 4806d66..38ac3ef 100644
--- a/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
+++ b/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
@@ -15,18 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from datetime import datetime, timedelta
+from datetime import datetime
 from os import getenv
 
-from airflow.decorators import dag, task
+from airflow import DAG
+from airflow.decorators import task
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
 
-# [START howto_operator_redshift_data_env_variables]
-REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "test-cluster")
-REDSHIFT_DATABASE = getenv("REDSHIFT_DATABASE", "test-database")
+REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "redshift_cluster_identifier")
+REDSHIFT_DATABASE = getenv("REDSHIFT_DATABASE", "redshift_database")
 REDSHIFT_DATABASE_USER = getenv("REDSHIFT_DATABASE_USER", "awsuser")
-# [END howto_operator_redshift_data_env_variables]
 
 REDSHIFT_QUERY = """
 SELECT table_schema,
@@ -40,29 +39,26 @@ ORDER BY table_schema,
 POLL_INTERVAL = 10
 
 
-# [START howto_redshift_data]
-@dag(
-    dag_id='example_redshift_data',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    dagrun_timeout=timedelta(minutes=60),
-    tags=['example'],
-    catchup=False,
-)
-def example_redshift_data():
-    @task(task_id="output_results")
-    def output_results_fn(id):
-        """This is a python decorator task that returns a Redshift query"""
-        hook = RedshiftDataHook()
+@task(task_id="output_results")
+def output_query_results(statement_id):
+    hook = RedshiftDataHook()
+    resp = hook.conn.get_statement_result(
+        Id=statement_id,
+    )
+
+    print(resp)
+    return resp
 
-        resp = hook.get_statement_result(
-            id=id,
-        )
-        print(resp)
-        return resp
 
-    # Run a SQL statement and wait for completion
-    redshift_query = RedshiftDataOperator(
+with DAG(
+    dag_id="example_redshift_data_execute_sql",
+    start_date=datetime(2021, 1, 1),
+    schedule_interval=None,
+    catchup=False,
+    tags=['example'],
+) as dag:
+    # [START howto_redshift_data]
+    task_query = RedshiftDataOperator(
         task_id='redshift_query',
         cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
         database=REDSHIFT_DATABASE,
@@ -71,10 +67,6 @@ def example_redshift_data():
         poll_interval=POLL_INTERVAL,
         await_result=True,
     )
+    # [END howto_redshift_data]
 
-    # Using a task-decorated function to output the list of tables in a Redshift cluster
-    output_results_fn(redshift_query.output)
-
-
-example_redshift_data_dag = example_redshift_data()
-# [END howto_redshift_data]
+    task_output = output_query_results(task_query.output)
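
For context on the example change above: the old task called get_statement_result directly on the hook, and the fix routes the call through hook.conn, the underlying boto3 "redshift-data" client, using the capitalised Id keyword that API expects. Below is a minimal sketch of unpacking that response into column names and rows, assuming the standard boto3 GetStatementResult shape (ColumnMetadata, Records) and a caller-supplied statement id; it is an illustration, not code from this commit:

    from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook


    def fetch_rows(statement_id: str):
        """Return (column_names, rows) for a finished Redshift Data statement."""
        hook = RedshiftDataHook()  # uses the default AWS connection
        # hook.conn is the boto3 "redshift-data" client; note the capitalised "Id".
        resp = hook.conn.get_statement_result(Id=statement_id)

        column_names = [col["name"] for col in resp["ColumnMetadata"]]
        # Each field comes back as a single-key dict, e.g. {"stringValue": "public"}.
        rows = [
            [next(iter(field.values()), None) for field in record]
            for record in resp["Records"]
        ]
        return column_names, rows
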
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index 977bc68..3961833 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,7 +17,7 @@
 # under the License.
 import sys
 from time import sleep
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 if sys.version_info >= (3, 8):
     from functools import cached_property
@@ -109,16 +109,19 @@ class RedshiftDataOperator(BaseOperator):
         return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
 
     def execute_query(self):
-        resp = self.hook.conn.execute_statement(
-            ClusterIdentifier=self.cluster_identifier,
-            Database=self.database,
-            DbUser=self.db_user,
-            Sql=self.sql,
-            Parameters=self.parameters,
-            SecretArn=self.secret_arn,
-            StatementName=self.statement_name,
-            WithEvent=self.with_event,
-        )
+        kwargs: Dict[str, Any] = {
+            "ClusterIdentifier": self.cluster_identifier,
+            "Database": self.database,
+            "Sql": self.sql,
+            "DbUser": self.db_user,
+            "Parameters": self.parameters,
+            "WithEvent": self.with_event,
+            "SecretArn": self.secret_arn,
+            "StatementName": self.statement_name,
+        }
+
+        filter_values = {key: val for key, val in kwargs.items() if val is not None}
+        resp = self.hook.conn.execute_statement(**filter_values)
         return resp['Id']
 
     def wait_for_results(self, statement_id):
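
The operator change above is the core of the fix: the previous code forwarded every optional argument to execute_statement, so unset fields such as ClusterIdentifier or SecretArn arrived as None and tripped botocore's parameter validation. The patch builds the keyword dict first and drops None entries before calling the API. A standalone sketch of the same pattern, outside the operator, with illustrative parameter names only:

    from typing import Any, Dict, List, Optional


    def build_execute_statement_kwargs(
        database: str,
        sql: str,
        cluster_identifier: Optional[str] = None,
        db_user: Optional[str] = None,
        parameters: Optional[List[Dict[str, str]]] = None,
        secret_arn: Optional[str] = None,
        statement_name: Optional[str] = None,
        with_event: bool = False,
    ) -> Dict[str, Any]:
        """Mirror the operator's kwargs construction, dropping unset (None) values."""
        kwargs: Dict[str, Any] = {
            "ClusterIdentifier": cluster_identifier,
            "Database": database,
            "Sql": sql,
            "DbUser": db_user,
            "Parameters": parameters,
            "WithEvent": with_event,
            "SecretArn": secret_arn,
            "StatementName": statement_name,
        }
        # botocore validates argument types eagerly, so optional fields must be
        # omitted entirely rather than passed as None.
        return {key: val for key, val in kwargs.items() if val is not None}


    # Only the required fields plus WithEvent survive the filter:
    print(build_execute_statement_kwargs(database="dev", sql="SELECT 1;"))
    # {'Database': 'dev', 'Sql': 'SELECT 1;', 'WithEvent': False}
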
diff --git a/docs/apache-airflow-providers-amazon/operators/redshift_data.rst b/docs/apache-airflow-providers-amazon/operators/redshift_data.rst
index 7362008..4080378 100644
--- a/docs/apache-airflow-providers-amazon/operators/redshift_data.rst
+++ b/docs/apache-airflow-providers-amazon/operators/redshift_data.rst
@@ -15,38 +15,37 @@
     specific language governing permissions and limitations
     under the License.
 
-.. _howto/operator:RedshiftDataOperator:
+Amazon Redshift Data Operators
+==============================
 
-RedshiftDataOperator
-====================
+Use the :class:`RedshiftDataOperator <airflow.providers.amazon.aws.operators.redshift_data>` to execute
+statements against an Amazon Redshift cluster.
 
-.. contents::
-  :depth: 1
-  :local:
+This differs from ``RedshiftSQLOperator`` in that it allows users to query and retrieve data via the AWS API and avoid the necessity of a Postgres connection.
 
-Overview
---------
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
 
-Use the :class:`RedshiftDataOperator <airflow.providers.amazon.aws.operators.redshift_data>` to execute
-statements against an Amazon Redshift cluster.
+.. include:: _partials/prerequisite_tasks.rst
 
-This differs from RedshiftSQLOperator in that it allows users to query and retrieve data via the AWS API and avoid the necessity of a Postgres connection.
+Amazon Redshift Data
+^^^^^^^^^^^^^^^^^^^^
 
-example_redshift_data_execute_sql.py
-------------------------------------
+.. _howto/operator:RedshiftDataOperator:
 
-Purpose
-"""""""
+Execute a statement on an Amazon Redshift Cluster
+"""""""""""""""""""""""""""""""""""""""""""""""""
 
 This is a basic example DAG for using :class:`RedshiftDataOperator <airflow.providers.amazon.aws.operators.redshift_data>`
 to execute statements against an Amazon Redshift cluster.
 
-List tables in database
-"""""""""""""""""""""""
-
-In the following code we list the tables in the provided database.
-
 .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
     :language: python
+    :dedent: 4
     :start-after: [START howto_redshift_data]
     :end-before: [END howto_redshift_data]
+
+Reference
+^^^^^^^^^
+
+ * `AWS boto3 Library Documentation for Amazon Redshift Data <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-data.html>`__
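
Since the rewritten doc page stresses that the operator goes through the Redshift Data API rather than opening a Postgres connection, a minimal sketch of the operator authenticating with a Secrets Manager ARN instead of db_user may be useful. The DAG id, cluster name, database, and ARN below are placeholders, not values from this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    with DAG(
        dag_id="example_redshift_data_secret_arn",  # hypothetical DAG id
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
        catchup=False,
        tags=["example"],
    ) as dag:
        count_tables = RedshiftDataOperator(
            task_id="count_tables",
            cluster_identifier="my-redshift-cluster",  # placeholder
            database="dev",  # placeholder
            secret_arn="arn:aws:secretsmanager:us-east-1:123456789012:secret:redshift-creds",  # placeholder
            sql="SELECT COUNT(*) FROM information_schema.tables;",
            poll_interval=10,
            await_result=True,
        )
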
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index c6a6943..9b2d134 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -40,38 +40,42 @@ class TestRedshiftDataOperator:
         )
         operator.execute(None)
         mock_conn.execute_statement.assert_called_once_with(
-            ClusterIdentifier=None,
             Database=DATABASE,
-            DbUser=None,
             Sql=SQL,
-            Parameters=None,
-            SecretArn=None,
-            StatementName=None,
             WithEvent=False,
         )
         mock_conn.describe_statement.assert_not_called()
 
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
-    def test_execute(self, mock_conn):
+    def test_execute_with_all_parameters(self, mock_conn):
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
         parameters = [{"name": "id", "value": "1"}]
         mock_conn.execute_statement.return_value = {'Id': STATEMENT_ID}
         mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+
         operator = RedshiftDataOperator(
             aws_conn_id=CONN_ID,
             task_id=TASK_ID,
             sql=SQL,
-            parameters=parameters,
             database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
         )
         operator.execute(None)
         mock_conn.execute_statement.assert_called_once_with(
-            ClusterIdentifier=None,
             Database=DATABASE,
-            DbUser=None,
             Sql=SQL,
+            ClusterIdentifier=cluster_identifier,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
             Parameters=parameters,
-            SecretArn=None,
-            StatementName=None,
             WithEvent=False,
         )
         mock_conn.describe_statement.assert_called_once_with(