You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by "Lee-W (via GitHub)" <gi...@apache.org> on 2023/07/17 15:15:45 UTC

[GitHub] [airflow] Lee-W commented on a diff in pull request #32437: add deferrable mode to rds start & stop DB

Lee-W commented on code in PR #32437:
URL: https://github.com/apache/airflow/pull/32437#discussion_r1265520231


##########
airflow/providers/amazon/aws/operators/rds.py:
##########
@@ -791,23 +829,47 @@ def __init__(
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
         db_snapshot_identifier: str | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.db_type = db_type
         self.db_snapshot_identifier = db_snapshot_identifier
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         self.db_type = RdsDbType(self.db_type)
-        stop_db_response = self._stop_db()
-        if self.wait_for_completion:
+        stop_db_response: dict[str, Any] = self._stop_db()
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbStoppedTrigger(
+                    db_identifier=self.db_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    response=stop_db_response,
+                    db_type=RdsDbType.INSTANCE,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             self._wait_until_db_stopped()
         return json.dumps(stop_db_response, default=str)
 
+    def execute_complete(self, context, event=None) -> str:

Review Comment:
   nitpick: missing type annotations for the `context` and `event` parameters



##########
airflow/providers/amazon/aws/operators/rds.py:
##########
@@ -62,13 +67,17 @@ def __init__(
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,  # 2 is in the operator's init, 3 is in the user code creating the operator
             )
-        hook_params = hook_params or {}
-        self.region_name = region_name or hook_params.pop("region_name", None)
-        self.hook = RdsHook(aws_conn_id=aws_conn_id, region_name=self.region_name, **(hook_params))
+        self.hook_params = hook_params or {}
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name or self.hook_params.pop("region_name", None)
         super().__init__(*args, **kwargs)
 
         self._await_interval = 60  # seconds
 
+    @cachedproperty
+    def hook(self):

Review Comment:
   ```suggestion
       def hook(self) -> RdsHook:
   ```



##########
airflow/providers/amazon/aws/triggers/rds.py:
##########
@@ -87,3 +93,151 @@ async def run(self):
                 status_args=["DBInstances[0].DBInstanceStatus"],
             )
         yield TriggerEvent({"status": "success", "response": self.response})
+
+
+_waiter_arg = {
+    RdsDbType.INSTANCE: "DBInstanceIdentifier",
+    RdsDbType.CLUSTER: "DBClusterIdentifier",
+}
+_status_paths = {
+    RdsDbType.INSTANCE: ["DBInstances[].DBInstanceStatus", "DBInstances[].StatusInfos"],
+    RdsDbType.CLUSTER: ["DBClusters[].Status"],
+}
+
+
+class RdsDbAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be available.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType,
+    ):

Review Comment:
   ```suggestion
       ) -> None:
   ```



##########
airflow/providers/amazon/aws/triggers/rds.py:
##########
@@ -18,11 +18,17 @@
 
 from typing import Any
 
+from deprecated import deprecated
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
+@deprecated(reason="Use the other specialized RDS triggers")

Review Comment:
   Should we add a guide somewhere so that users have a better idea of what `other specialized RDS triggers` means and how to migrate?



##########
airflow/providers/amazon/aws/triggers/rds.py:
##########
@@ -87,3 +93,151 @@ async def run(self):
                 status_args=["DBInstances[0].DBInstanceStatus"],
             )
         yield TriggerEvent({"status": "success", "response": self.response})
+
+
+_waiter_arg = {
+    RdsDbType.INSTANCE: "DBInstanceIdentifier",
+    RdsDbType.CLUSTER: "DBClusterIdentifier",
+}
+_status_paths = {
+    RdsDbType.INSTANCE: ["DBInstances[].DBInstanceStatus", "DBInstances[].StatusInfos"],
+    RdsDbType.CLUSTER: ["DBClusters[].Status"],
+}
+
+
+class RdsDbAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be available.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType,
+    ):
+        super().__init__(
+            serialized_fields={
+                "db_identifier": db_identifier,
+                "response": response,
+                "db_type": db_type,
+            },
+            waiter_name=f"db_{db_type.value}_available",
+            waiter_args={_waiter_arg[db_type]: db_identifier},
+            failure_message="Error while waiting for DB to be available",
+            status_message="DB initialization in progress",
+            status_queries=_status_paths[db_type],
+            return_key="response",
+            return_value=response,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            region_name=region_name,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+
+class RdsDbDeletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be deleted.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType,
+    ):
+        super().__init__(
+            serialized_fields={
+                "db_identifier": db_identifier,
+                "response": response,
+                "db_type": db_type,
+            },
+            waiter_name=f"db_{db_type.value}_deleted",
+            waiter_args={_waiter_arg[db_type]: db_identifier},
+            failure_message="Error while deleting DB",
+            status_message="DB deletion in progress",
+            status_queries=_status_paths[db_type],
+            return_key="response",
+            return_value=response,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            region_name=region_name,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+
+class RdsDbStoppedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be stopped.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType,
+    ):

Review Comment:
   ```suggestion
       ) -> None:
   ```



##########
airflow/providers/amazon/aws/utils/rds.py:
##########
@@ -22,5 +22,5 @@
 class RdsDbType(Enum):
     """Only available types for the RDS."""
 
-    INSTANCE: str = "instance"
-    CLUSTER: str = "cluster"

Review Comment:
   Should we add the correct annotation here?



##########
airflow/providers/amazon/aws/triggers/rds.py:
##########
@@ -87,3 +93,151 @@ async def run(self):
                 status_args=["DBInstances[0].DBInstanceStatus"],
             )
         yield TriggerEvent({"status": "success", "response": self.response})
+
+
+_waiter_arg = {
+    RdsDbType.INSTANCE: "DBInstanceIdentifier",
+    RdsDbType.CLUSTER: "DBClusterIdentifier",
+}
+_status_paths = {
+    RdsDbType.INSTANCE: ["DBInstances[].DBInstanceStatus", "DBInstances[].StatusInfos"],
+    RdsDbType.CLUSTER: ["DBClusters[].Status"],
+}
+
+
+class RdsDbAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be available.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType,
+    ):
+        super().__init__(
+            serialized_fields={
+                "db_identifier": db_identifier,
+                "response": response,
+                "db_type": db_type,
+            },
+            waiter_name=f"db_{db_type.value}_available",
+            waiter_args={_waiter_arg[db_type]: db_identifier},
+            failure_message="Error while waiting for DB to be available",
+            status_message="DB initialization in progress",
+            status_queries=_status_paths[db_type],
+            return_key="response",
+            return_value=response,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            region_name=region_name,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+
+class RdsDbDeletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be deleted.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType,
+    ):

Review Comment:
   ```suggestion
       ) -> None:
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org