You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2023/06/27 08:18:59 UTC
[airflow] 01/01: Revert "add deferrable mode for `AthenaOperator` (#32110)"
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch revert-32110-vandonr/athena
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 1bed1f9deee76897c86239a015d7f5332113c0e7
Author: eladkal <45...@users.noreply.github.com>
AuthorDate: Tue Jun 27 11:18:51 2023 +0300
Revert "add deferrable mode for `AthenaOperator` (#32110)"
This reverts commit 256438c3d6a80c989c68d2e0f3c8549108770f0e.
---
airflow/providers/amazon/aws/operators/athena.py | 20 +-----
airflow/providers/amazon/aws/triggers/athena.py | 76 ----------------------
airflow/providers/amazon/provider.yaml | 3 -
.../providers/amazon/aws/operators/test_athena.py | 12 ----
tests/providers/amazon/aws/triggers/test_athena.py | 53 ---------------
5 files changed, 1 insertion(+), 163 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 990f2ec414..612e563ce6 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -20,10 +20,8 @@ from __future__ import annotations
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
-from airflow import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -71,7 +69,6 @@ class AthenaOperator(BaseOperator):
sleep_time: int = 30,
max_polling_attempts: int | None = None,
log_query: bool = True,
- deferrable: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -84,10 +81,9 @@ class AthenaOperator(BaseOperator):
self.query_execution_context = query_execution_context or {}
self.result_configuration = result_configuration or {}
self.sleep_time = sleep_time
- self.max_polling_attempts = max_polling_attempts or 999999
+ self.max_polling_attempts = max_polling_attempts
self.query_execution_id: str | None = None
self.log_query: bool = log_query
- self.deferrable = deferrable
@cached_property
def hook(self) -> AthenaHook:
@@ -105,15 +101,6 @@ class AthenaOperator(BaseOperator):
self.client_request_token,
self.workgroup,
)
-
- if self.deferrable:
- self.defer(
- trigger=AthenaTrigger(
- self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
- ),
- method_name="execute_complete",
- )
- # implicit else:
query_status = self.hook.poll_query_status(
self.query_execution_id,
max_polling_attempts=self.max_polling_attempts,
@@ -134,11 +121,6 @@ class AthenaOperator(BaseOperator):
return self.query_execution_id
- def execute_complete(self, context, event=None):
- if event["status"] != "success":
- raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
- return event["value"]
-
def on_kill(self) -> None:
"""Cancel the submitted athena query."""
if self.query_execution_id:
diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py
deleted file mode 100644
index 780d9e9b98..0000000000
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from typing import Any
-
-from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
-from airflow.triggers.base import BaseTrigger, TriggerEvent
-
-
-class AthenaTrigger(BaseTrigger):
- """
- Trigger for RedshiftCreateClusterOperator.
-
- The trigger will asynchronously poll the boto3 API and wait for the
- Redshift cluster to be in the `available` state.
-
- :param query_execution_id: ID of the Athena query execution to watch
- :param poll_interval: The amount of time in seconds to wait between attempts.
- :param max_attempt: The maximum number of attempts to be made.
- :param aws_conn_id: The Airflow connection used for AWS credentials.
- """
-
- def __init__(
- self,
- query_execution_id: str,
- poll_interval: int,
- max_attempt: int,
- aws_conn_id: str,
- ):
- self.query_execution_id = query_execution_id
- self.poll_interval = poll_interval
- self.max_attempt = max_attempt
- self.aws_conn_id = aws_conn_id
-
- def serialize(self) -> tuple[str, dict[str, Any]]:
- return (
- self.__class__.__module__ + "." + self.__class__.__qualname__,
- {
- "query_execution_id": str(self.query_execution_id),
- "poll_interval": str(self.poll_interval),
- "max_attempt": str(self.max_attempt),
- "aws_conn_id": str(self.aws_conn_id),
- },
- )
-
- async def run(self):
- hook = AthenaHook(self.aws_conn_id)
- async with hook.async_conn as client:
- waiter = hook.get_waiter("query_complete", deferrable=True, client=client)
- await async_wait(
- waiter=waiter,
- waiter_delay=self.poll_interval,
- max_attempts=self.max_attempt,
- args={"QueryExecutionId": self.query_execution_id},
- failure_message=f"Error while waiting for query {self.query_execution_id} to complete",
- status_message=f"Query execution id: {self.query_execution_id}, "
- "Query is still in non-terminal state",
- status_args=["QueryExecution.Status.State"],
- )
- yield TriggerEvent({"status": "success", "value": self.query_execution_id})
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index e4f16ce398..5439f9c8cb 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -515,9 +515,6 @@ hooks:
- airflow.providers.amazon.aws.hooks.appflow
triggers:
- - integration-name: Amazon Athena
- python-modules:
- - airflow.providers.amazon.aws.triggers.athena
- integration-name: AWS Batch
python-modules:
- airflow.providers.amazon.aws.triggers.batch
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 9e52852520..cfc7869768 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -20,11 +20,9 @@ from unittest import mock
import pytest
-from airflow.exceptions import TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.operators.athena import AthenaOperator
-from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
from airflow.utils import timezone
from airflow.utils.timezone import datetime
@@ -160,13 +158,3 @@ class TestAthenaOperator:
ti.dag_run = dag_run
assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID
-
- @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
- def test_is_deferred(self, mock_run_query):
- self.athena.deferrable = True
-
- with pytest.raises(TaskDeferred) as deferred:
- self.athena.execute(None)
-
- assert isinstance(deferred.value.trigger, AthenaTrigger)
- assert deferred.value.trigger.query_execution_id == ATHENA_QUERY_ID
diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py
deleted file mode 100644
index 04e601f439..0000000000
--- a/tests/providers/amazon/aws/triggers/test_athena.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from unittest import mock
-from unittest.mock import AsyncMock
-
-import pytest
-from botocore.exceptions import WaiterError
-
-from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
-
-
-class TestAthenaTrigger:
- @pytest.mark.asyncio
- @mock.patch.object(AthenaHook, "get_waiter")
- @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI fails without this
- async def test_run_with_error(self, conn_mock, waiter_mock):
- waiter_mock.side_effect = WaiterError("name", "reason", {})
-
- trigger = AthenaTrigger("query_id", 0, 5, None)
-
- with pytest.raises(WaiterError):
- generator = trigger.run()
- await generator.asend(None)
-
- @pytest.mark.asyncio
- @mock.patch.object(AthenaHook, "get_waiter")
- @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI fails without this
- async def test_run_success(self, conn_mock, waiter_mock):
- waiter_mock().wait = AsyncMock()
- trigger = AthenaTrigger("my_query_id", 0, 5, None)
-
- generator = trigger.run()
- event = await generator.asend(None)
-
- assert event.payload["status"] == "success"
- assert event.payload["value"] == "my_query_id"