Posted to commits@airflow.apache.org by el...@apache.org on 2023/06/27 08:18:59 UTC

[airflow] 01/01: Revert "add deferrable mode for `AthenaOperator` (#32110)"

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch revert-32110-vandonr/athena
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 1bed1f9deee76897c86239a015d7f5332113c0e7
Author: eladkal <45...@users.noreply.github.com>
AuthorDate: Tue Jun 27 11:18:51 2023 +0300

    Revert "add deferrable mode for `AthenaOperator` (#32110)"
    
    This reverts commit 256438c3d6a80c989c68d2e0f3c8549108770f0e.
---
 airflow/providers/amazon/aws/operators/athena.py   | 20 +-----
 airflow/providers/amazon/aws/triggers/athena.py    | 76 ----------------------
 airflow/providers/amazon/provider.yaml             |  3 -
 .../providers/amazon/aws/operators/test_athena.py  | 12 ----
 tests/providers/amazon/aws/triggers/test_athena.py | 53 ---------------
 5 files changed, 1 insertion(+), 163 deletions(-)
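For context, the reverted change had added an opt-in `deferrable` flag to `AthenaOperator` (see the removed lines in the diff below). A minimal sketch of how a DAG would have enabled it before this revert — the DAG id, query, database name, and S3 output location here are hypothetical placeholders, not taken from the commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.athena import AthenaOperator

    with DAG(
        dag_id="athena_deferrable_example",  # hypothetical DAG id
        start_date=datetime(2023, 1, 1),
        schedule=None,
    ) as dag:
        # With the reverted change in place, deferrable=True handed polling off to
        # AthenaTrigger instead of blocking a worker slot in poll_query_status().
        run_query = AthenaOperator(
            task_id="run_query",
            query="SELECT 1",                                   # hypothetical query
            database="my_database",                             # hypothetical database
            output_location="s3://my-bucket/athena-results/",   # hypothetical S3 path
            sleep_time=30,
            deferrable=True,  # parameter removed again by this revert
        )

After this revert, `deferrable` is no longer accepted by the operator and the task falls back to synchronous polling via `AthenaHook.poll_query_status`, as shown in the restored code below.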

diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 990f2ec414..612e563ce6 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -20,10 +20,8 @@ from __future__ import annotations
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
-from airflow import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -71,7 +69,6 @@ class AthenaOperator(BaseOperator):
         sleep_time: int = 30,
         max_polling_attempts: int | None = None,
         log_query: bool = True,
-        deferrable: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -84,10 +81,9 @@ class AthenaOperator(BaseOperator):
         self.query_execution_context = query_execution_context or {}
         self.result_configuration = result_configuration or {}
         self.sleep_time = sleep_time
-        self.max_polling_attempts = max_polling_attempts or 999999
+        self.max_polling_attempts = max_polling_attempts
         self.query_execution_id: str | None = None
         self.log_query: bool = log_query
-        self.deferrable = deferrable
 
     @cached_property
     def hook(self) -> AthenaHook:
@@ -105,15 +101,6 @@ class AthenaOperator(BaseOperator):
             self.client_request_token,
             self.workgroup,
         )
-
-        if self.deferrable:
-            self.defer(
-                trigger=AthenaTrigger(
-                    self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
-                ),
-                method_name="execute_complete",
-            )
-        # implicit else:
         query_status = self.hook.poll_query_status(
             self.query_execution_id,
             max_polling_attempts=self.max_polling_attempts,
@@ -134,11 +121,6 @@ class AthenaOperator(BaseOperator):
 
         return self.query_execution_id
 
-    def execute_complete(self, context, event=None):
-        if event["status"] != "success":
-            raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
-        return event["value"]
-
     def on_kill(self) -> None:
         """Cancel the submitted athena query."""
         if self.query_execution_id:
diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py
deleted file mode 100644
index 780d9e9b98..0000000000
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from typing import Any
-
-from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
-from airflow.triggers.base import BaseTrigger, TriggerEvent
-
-
-class AthenaTrigger(BaseTrigger):
-    """
-    Trigger for RedshiftCreateClusterOperator.
-
-    The trigger will asynchronously poll the boto3 API and wait for the
-    Redshift cluster to be in the `available` state.
-
-    :param query_execution_id:  ID of the Athena query execution to watch
-    :param poll_interval: The amount of time in seconds to wait between attempts.
-    :param max_attempt: The maximum number of attempts to be made.
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-    """
-
-    def __init__(
-        self,
-        query_execution_id: str,
-        poll_interval: int,
-        max_attempt: int,
-        aws_conn_id: str,
-    ):
-        self.query_execution_id = query_execution_id
-        self.poll_interval = poll_interval
-        self.max_attempt = max_attempt
-        self.aws_conn_id = aws_conn_id
-
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        return (
-            self.__class__.__module__ + "." + self.__class__.__qualname__,
-            {
-                "query_execution_id": str(self.query_execution_id),
-                "poll_interval": str(self.poll_interval),
-                "max_attempt": str(self.max_attempt),
-                "aws_conn_id": str(self.aws_conn_id),
-            },
-        )
-
-    async def run(self):
-        hook = AthenaHook(self.aws_conn_id)
-        async with hook.async_conn as client:
-            waiter = hook.get_waiter("query_complete", deferrable=True, client=client)
-            await async_wait(
-                waiter=waiter,
-                waiter_delay=self.poll_interval,
-                max_attempts=self.max_attempt,
-                args={"QueryExecutionId": self.query_execution_id},
-                failure_message=f"Error while waiting for query {self.query_execution_id} to complete",
-                status_message=f"Query execution id: {self.query_execution_id}, "
-                "Query is still in non-terminal state",
-                status_args=["QueryExecution.Status.State"],
-            )
-        yield TriggerEvent({"status": "success", "value": self.query_execution_id})
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index e4f16ce398..5439f9c8cb 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -515,9 +515,6 @@ hooks:
       - airflow.providers.amazon.aws.hooks.appflow
 
 triggers:
-  - integration-name: Amazon Athena
-    python-modules:
-      - airflow.providers.amazon.aws.triggers.athena
   - integration-name: AWS Batch
     python-modules:
       - airflow.providers.amazon.aws.triggers.batch
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 9e52852520..cfc7869768 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -20,11 +20,9 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import TaskDeferred
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.operators.athena import AthenaOperator
-from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
 from airflow.utils import timezone
 from airflow.utils.timezone import datetime
 
@@ -160,13 +158,3 @@ class TestAthenaOperator:
         ti.dag_run = dag_run
 
         assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID
-
-    @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
-    def test_is_deferred(self, mock_run_query):
-        self.athena.deferrable = True
-
-        with pytest.raises(TaskDeferred) as deferred:
-            self.athena.execute(None)
-
-        assert isinstance(deferred.value.trigger, AthenaTrigger)
-        assert deferred.value.trigger.query_execution_id == ATHENA_QUERY_ID
diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py
deleted file mode 100644
index 04e601f439..0000000000
--- a/tests/providers/amazon/aws/triggers/test_athena.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from unittest import mock
-from unittest.mock import AsyncMock
-
-import pytest
-from botocore.exceptions import WaiterError
-
-from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
-
-
-class TestAthenaTrigger:
-    @pytest.mark.asyncio
-    @mock.patch.object(AthenaHook, "get_waiter")
-    @mock.patch.object(AthenaHook, "async_conn")  # LatestBoto step of CI fails without this
-    async def test_run_with_error(self, conn_mock, waiter_mock):
-        waiter_mock.side_effect = WaiterError("name", "reason", {})
-
-        trigger = AthenaTrigger("query_id", 0, 5, None)
-
-        with pytest.raises(WaiterError):
-            generator = trigger.run()
-            await generator.asend(None)
-
-    @pytest.mark.asyncio
-    @mock.patch.object(AthenaHook, "get_waiter")
-    @mock.patch.object(AthenaHook, "async_conn")  # LatestBoto step of CI fails without this
-    async def test_run_success(self, conn_mock, waiter_mock):
-        waiter_mock().wait = AsyncMock()
-        trigger = AthenaTrigger("my_query_id", 0, 5, None)
-
-        generator = trigger.run()
-        event = await generator.asend(None)
-
-        assert event.payload["status"] == "success"
-        assert event.payload["value"] == "my_query_id"