You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2022/12/02 20:38:14 UTC

[GitHub] [airflow] vincbeck commented on a diff in pull request #27947: refactored Amazon Redshift-data functionality into the hook

vincbeck commented on code in PR #27947:
URL: https://github.com/apache/airflow/pull/27947#discussion_r1038511636


##########
airflow/providers/amazon/aws/hooks/redshift_data.py:
##########
@@ -46,3 +48,73 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["client_type"] = "redshift-data"
         super().__init__(*args, **kwargs)
+
+    def execute_query(
+        self,
+        database: str,
+        sql: str | list[str],
+        cluster_identifier: str | None = None,
+        db_user: str | None = None,
+        parameters: list | None = None,
+        secret_arn: str | None = None,
+        statement_name: str | None = None,
+        with_event: bool = False,
+        await_result: bool = True,

Review Comment:
   By convention, this kind of flag is usually named `wait_for_completion` rather than `await_result`.



##########
airflow/providers/amazon/aws/hooks/redshift_data.py:
##########
@@ -46,3 +48,73 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["client_type"] = "redshift-data"
         super().__init__(*args, **kwargs)
+
+    def execute_query(
+        self,
+        database: str,
+        sql: str | list[str],
+        cluster_identifier: str | None = None,
+        db_user: str | None = None,
+        parameters: list | None = None,
+        secret_arn: str | None = None,
+        statement_name: str | None = None,
+        with_event: bool = False,
+        await_result: bool = True,
+        poll_interval: int = 10,
+    ) -> str:
+        """
+        Execute a statement against Amazon Redshift
+
+        :param database: the name of the database
+        :param sql: the SQL statement or list of  SQL statement to run
+        :param cluster_identifier: unique identifier of a cluster
+        :param db_user: the database username
+        :param parameters: the parameters for the SQL statement
+        :param secret_arn: the name or ARN of the secret that enables db access
+        :param statement_name: the name of the SQL statement
+        :param with_event: indicates whether to send an event to EventBridge
+        :param await_result: indicates whether to wait for a result, if True wait, if False don't wait
+        :param poll_interval: how often in seconds to check the query status
+
+        :returns statement_id: str, the UUID of the statement
+        """
+        kwargs: dict[str, Any] = {
+            "ClusterIdentifier": cluster_identifier,
+            "Database": database,
+            "DbUser": db_user,
+            "Parameters": parameters,
+            "WithEvent": with_event,
+            "SecretArn": secret_arn,
+            "StatementName": statement_name,
+        }
+        if isinstance(sql, list):
+            kwargs["Sqls"] = sql
+            resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
+        else:
+            kwargs["Sql"] = sql
+            resp = self.conn.execute_statement(**trim_none_values(kwargs))
+
+        statement_id = resp["Id"]
+
+        if await_result:
+            self.wait_for_results(statement_id, poll_interval=poll_interval)
+
+        return statement_id
+
+    def wait_for_results(self, statement_id, poll_interval):
+        while True:
+            self.log.info("Polling statement %s", statement_id)
+            resp = self.conn.describe_statement(
+                Id=statement_id,
+            )
+            status = resp["Status"]
+            if status == "FINISHED":
+                return status
+            elif status == "FAILED" or status == "ABORTED":
+                raise ValueError(
+                    f"Statement {statement_id!r} terminated with status {status}, "
+                    f"error msg: {resp.get('Error')}"
+                )
+            else:
+                self.log.info("Query %s", status)

Review Comment:
   ```suggestion
               self.log.info("Query %s", status)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org