Posted to commits@airflow.apache.org by ka...@apache.org on 2022/08/10 11:15:04 UTC

[airflow] branch main updated: Enable multiple query execution in RedshiftDataOperator (#25619)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 358593c6b6 Enable multiple query execution in RedshiftDataOperator (#25619)
358593c6b6 is described below

commit 358593c6b65620807103ae16946825e0bfad974f
Author: Pankaj Singh <98...@users.noreply.github.com>
AuthorDate: Wed Aug 10 16:44:56 2022 +0530

    Enable multiple query execution in RedshiftDataOperator (#25619)
    
    Enable RedshiftDataOperator to execute a batch of SQL statements using the batch_execute_statement boto3 API.
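
    For context, a minimal sketch of how the new list form might be used in
    a DAG; the dag_id, task_id, cluster, database, and SQL values below are
    made-up placeholders, not part of this commit:

        import pendulum

        from airflow import DAG
        from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

        with DAG(
            dag_id="example_redshift_data_batch",
            start_date=pendulum.datetime(2022, 8, 1, tz="UTC"),
            schedule_interval=None,
            catchup=False,
        ) as dag:
            # Passing a list of statements makes the operator call the
            # batch_execute_statement Redshift Data API; a plain string
            # still goes through execute_statement as before.
            setup_tables = RedshiftDataOperator(
                task_id="setup_tables",
                cluster_identifier="my-cluster",
                database="dev",
                db_user="awsuser",
                sql=[
                    "CREATE TABLE IF NOT EXISTS fruit (id INT, name VARCHAR(64))",
                    "INSERT INTO fruit VALUES (1, 'apple')",
                ],
                await_result=True,
            )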
---
 .../amazon/aws/operators/redshift_data.py          | 26 +++++++++++++++----
 .../amazon/aws/operators/test_redshift_data.py     | 29 ++++++++++++++++++++++
 2 files changed, 50 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index a2400f94bd..41d1734789 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from time import sleep
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from airflow.compat.functools import cached_property
 from airflow.models import BaseOperator
@@ -36,7 +36,7 @@ class RedshiftDataOperator(BaseOperator):
         :ref:`howto/operator:RedshiftDataOperator`
 
     :param database: the name of the database
-    :param sql: the SQL statement text to run
+    :param sql: the SQL statement or list of SQL statements to run
     :param cluster_identifier: unique identifier of a cluster
     :param db_user: the database username
     :param parameters: the parameters for the SQL statement
@@ -65,7 +65,7 @@ class RedshiftDataOperator(BaseOperator):
     def __init__(
         self,
         database: str,
-        sql: str,
+        sql: Union[str, List],
         cluster_identifier: Optional[str] = None,
         db_user: Optional[str] = None,
         parameters: Optional[list] = None,
@@ -119,6 +119,20 @@ class RedshiftDataOperator(BaseOperator):
         resp = self.hook.conn.execute_statement(**trim_none_values(kwargs))
         return resp['Id']
 
+    def execute_batch_query(self):
+        kwargs: Dict[str, Any] = {
+            "ClusterIdentifier": self.cluster_identifier,
+            "Database": self.database,
+            "Sqls": self.sql,
+            "DbUser": self.db_user,
+            "Parameters": self.parameters,
+            "WithEvent": self.with_event,
+            "SecretArn": self.secret_arn,
+            "StatementName": self.statement_name,
+        }
+        resp = self.hook.conn.batch_execute_statement(**trim_none_values(kwargs))
+        return resp['Id']
+
     def wait_for_results(self, statement_id):
         while True:
             self.log.info("Polling statement %s", statement_id)
@@ -137,8 +151,10 @@ class RedshiftDataOperator(BaseOperator):
     def execute(self, context: 'Context') -> None:
         """Execute a statement against Amazon Redshift"""
         self.log.info("Executing statement: %s", self.sql)
-
-        self.statement_id = self.execute_query()
+        if isinstance(self.sql, list):
+            self.statement_id = self.execute_batch_query()
+        else:
+            self.statement_id = self.execute_query()
 
         if self.await_result:
             self.wait_for_results(self.statement_id)
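
For readers unfamiliar with the underlying API, here is a rough standalone
equivalent of execute_batch_query() plus the polling loop in
wait_for_results(), written directly against boto3; the cluster, database,
user, and region values are placeholders:

    from time import sleep

    import boto3

    client = boto3.client("redshift-data", region_name="us-east-1")

    # batch_execute_statement takes the statements under "Sqls" and returns
    # a single Id covering the whole batch.
    resp = client.batch_execute_statement(
        ClusterIdentifier="my-cluster",
        Database="dev",
        DbUser="awsuser",
        Sqls=["SELECT 1", "SELECT 2"],
    )
    statement_id = resp["Id"]

    # Poll describe_statement until the batch reaches a terminal state.
    while True:
        status = client.describe_statement(Id=statement_id)["Status"]
        if status in ("FINISHED", "FAILED", "ABORTED"):
            break
        sleep(10)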
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index 9b2d13416f..f98a75c445 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -110,3 +110,32 @@ class TestRedshiftDataOperator:
         mock_conn.cancel_statement.assert_called_once_with(
             Id=STATEMENT_ID,
         )
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_batch_execute(self, mock_conn):
+        mock_conn.batch_execute_statement.return_value = {'Id': STATEMENT_ID}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        operator = RedshiftDataOperator(
+            task_id=TASK_ID,
+            cluster_identifier=cluster_identifier,
+            database=DATABASE,
+            db_user=db_user,
+            sql=[SQL],
+            statement_name=statement_name,
+            secret_arn=secret_arn,
+            aws_conn_id=CONN_ID,
+        )
+        operator.execute(None)
+        mock_conn.batch_execute_statement.assert_called_once_with(
+            Database=DATABASE,
+            Sqls=[SQL],
+            ClusterIdentifier=cluster_identifier,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
+            WithEvent=False,
+        )
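
A side note on why the operator can build the kwargs dict unconditionally:
trim_none_values strips None entries before the kwargs reach boto3, so
optional parameters such as SecretArn are simply omitted when unset. A
minimal re-implementation for illustration only; the real helper lives in
airflow.providers.amazon.aws.utils:

    def trim_none_values(d: dict) -> dict:
        # Drop keys whose value is None so boto3 never sees them.
        return {key: val for key, val in d.items() if val is not None}

    assert trim_none_values({"Database": "dev", "DbUser": None}) == {"Database": "dev"}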