You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2022/08/10 11:15:04 UTC
[airflow] branch main updated: Enable multiple query execution in RedshiftDataOperator (#25619)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 358593c6b6 Enable multiple query execution in RedshiftDataOperator (#25619)
358593c6b6 is described below
commit 358593c6b65620807103ae16946825e0bfad974f
Author: Pankaj Singh <98...@users.noreply.github.com>
AuthorDate: Wed Aug 10 16:44:56 2022 +0530
Enable multiple query execution in RedshiftDataOperator (#25619)
Enable RedshiftDataOperator to execute a batch of SQL statements using the boto3 batch_execute_statement API.
---
.../amazon/aws/operators/redshift_data.py | 26 +++++++++++++++----
.../amazon/aws/operators/test_redshift_data.py | 29 ++++++++++++++++++++++
2 files changed, 50 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index a2400f94bd..41d1734789 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
from time import sleep
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from airflow.compat.functools import cached_property
from airflow.models import BaseOperator
@@ -36,7 +36,7 @@ class RedshiftDataOperator(BaseOperator):
:ref:`howto/operator:RedshiftDataOperator`
:param database: the name of the database
- :param sql: the SQL statement text to run
+ :param sql: the SQL statement or list of SQL statements to run
:param cluster_identifier: unique identifier of a cluster
:param db_user: the database username
:param parameters: the parameters for the SQL statement
@@ -65,7 +65,7 @@ class RedshiftDataOperator(BaseOperator):
def __init__(
self,
database: str,
- sql: str,
+ sql: Union[str, List],
cluster_identifier: Optional[str] = None,
db_user: Optional[str] = None,
parameters: Optional[list] = None,
@@ -119,6 +119,20 @@ class RedshiftDataOperator(BaseOperator):
resp = self.hook.conn.execute_statement(**trim_none_values(kwargs))
return resp['Id']
+ def execute_batch_query(self):
+ kwargs: Dict[str, Any] = {
+ "ClusterIdentifier": self.cluster_identifier,
+ "Database": self.database,
+ "Sqls": self.sql,
+ "DbUser": self.db_user,
+ "Parameters": self.parameters,
+ "WithEvent": self.with_event,
+ "SecretArn": self.secret_arn,
+ "StatementName": self.statement_name,
+ }
+ resp = self.hook.conn.batch_execute_statement(**trim_none_values(kwargs))
+ return resp['Id']
+
def wait_for_results(self, statement_id):
while True:
self.log.info("Polling statement %s", statement_id)
@@ -137,8 +151,10 @@ class RedshiftDataOperator(BaseOperator):
def execute(self, context: 'Context') -> None:
"""Execute a statement against Amazon Redshift"""
self.log.info("Executing statement: %s", self.sql)
-
- self.statement_id = self.execute_query()
+ if isinstance(self.sql, list):
+ self.statement_id = self.execute_batch_query()
+ else:
+ self.statement_id = self.execute_query()
if self.await_result:
self.wait_for_results(self.statement_id)
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index 9b2d13416f..f98a75c445 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -110,3 +110,32 @@ class TestRedshiftDataOperator:
mock_conn.cancel_statement.assert_called_once_with(
Id=STATEMENT_ID,
)
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_batch_execute(self, mock_conn):
+ mock_conn.execute_statement.return_value = {'Id': STATEMENT_ID}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ operator = RedshiftDataOperator(
+ task_id=TASK_ID,
+ cluster_identifier=cluster_identifier,
+ database=DATABASE,
+ db_user=db_user,
+ sql=[SQL],
+ statement_name=statement_name,
+ secret_arn=secret_arn,
+ aws_conn_id=CONN_ID,
+ )
+ operator.execute(None)
+ mock_conn.batch_execute_statement.assert_called_once_with(
+ Database=DATABASE,
+ Sqls=[SQL],
+ ClusterIdentifier=cluster_identifier,
+ DbUser=db_user,
+ SecretArn=secret_arn,
+ StatementName=statement_name,
+ WithEvent=False,
+ )