You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2024/02/15 13:58:43 UTC
(airflow) branch main updated: Revert "KPO Maintain backward compatibility for execute_complete and trigger run method (#37363)" (#37446)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0be6430938 Revert "KPO Maintain backward compatibility for execute_complete and trigger run method (#37363)" (#37446)
0be6430938 is described below
commit 0be643093879e106f7ee1e41c155954edd14398f
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Thu Feb 15 14:58:36 2024 +0100
Revert "KPO Maintain backward compatibility for execute_complete and trigger run method (#37363)" (#37446)
This reverts commit 0640e6d595c01dd96f2b90812a546bc091f87743.
---
airflow/providers/cncf/kubernetes/operators/pod.py | 150 ++++++++++++---------
airflow/providers/cncf/kubernetes/triggers/pod.py | 70 +++-------
.../cncf/kubernetes/operators/test_pod.py | 34 ++---
.../providers/cncf/kubernetes/triggers/test_pod.py | 92 ++++++-------
.../cloud/triggers/test_kubernetes_engine.py | 51 ++++---
5 files changed, 189 insertions(+), 208 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index 61442a6014..73389f4038 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -18,7 +18,6 @@
from __future__ import annotations
-import datetime
import json
import logging
import re
@@ -31,7 +30,6 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
import kubernetes
-from deprecated import deprecated
from kubernetes.client import CoreV1Api, V1Pod, models as k8s
from kubernetes.stream import stream
from urllib3.exceptions import HTTPError
@@ -70,6 +68,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import (
EMPTY_XCOM_RESULT,
OnFinishAction,
PodLaunchFailedException,
+ PodLaunchTimeoutException,
PodManager,
PodNotFoundException,
PodOperatorHookProtocol,
@@ -80,6 +79,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import (
from airflow.settings import pod_mutation_hook
from airflow.utils import yaml
from airflow.utils.helpers import prune_dict, validate_key
+from airflow.utils.timezone import utcnow
from airflow.version import version as airflow_version
if TYPE_CHECKING:
@@ -656,7 +656,7 @@ class KubernetesPodOperator(BaseOperator):
def invoke_defer_method(self, last_log_time: DateTime | None = None):
"""Redefine triggers which are being used in child classes."""
- trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
+ trigger_start_time = utcnow()
self.defer(
trigger=KubernetesPodTrigger(
pod_name=self.pod.metadata.name, # type: ignore[union-attr]
@@ -678,87 +678,117 @@ class KubernetesPodOperator(BaseOperator):
method_name="trigger_reentry",
)
+ @staticmethod
+ def raise_for_trigger_status(event: dict[str, Any]) -> None:
+ """Raise exception if pod is not in expected state."""
+ if event["status"] == "error":
+ error_type = event["error_type"]
+ description = event["description"]
+ if error_type == "PodLaunchTimeoutException":
+ raise PodLaunchTimeoutException(description)
+ else:
+ raise AirflowException(description)
+
def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
"""
Point of re-entry from trigger.
- If ``logging_interval`` is None, then at this point, the pod should be done, and we'll just fetch
+ If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
the logs and exit.
- If ``logging_interval`` is not None, it could be that the pod is still running, and we'll just
+ If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
grab the latest logs and defer back to the trigger again.
"""
- self.pod = None
+ remote_pod = None
try:
- pod_name = event["name"]
- pod_namespace = event["namespace"]
+ self.pod_request_obj = self.build_pod_request_obj(context)
+ self.pod = self.find_pod(
+ namespace=self.namespace or self.pod_request_obj.metadata.namespace,
+ context=context,
+ )
- self.pod = self.hook.get_pod(pod_name, pod_namespace)
+ # we try to find pod before possibly raising so that on_kill will have `pod` attr
+ self.raise_for_trigger_status(event)
if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")
- if self.callbacks and event["status"] != "running":
- self.callbacks.on_operator_resuming(
- pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC
+ if self.get_logs:
+ last_log_time = event and event.get("last_log_time")
+ if last_log_time:
+ self.log.info("Resuming logs read from time %r", last_log_time)
+ pod_log_status = self.pod_manager.fetch_container_logs(
+ pod=self.pod,
+ container_name=self.BASE_CONTAINER_NAME,
+ follow=self.logging_interval is None,
+ since_time=last_log_time,
)
+ if pod_log_status.running:
+ self.log.info("Container still running; deferring again.")
+ self.invoke_defer_method(pod_log_status.last_log_time)
+
+ if self.do_xcom_push:
+ result = self.extract_xcom(pod=self.pod)
+ remote_pod = self.pod_manager.await_pod_completion(self.pod)
+ except TaskDeferred:
+ raise
+ except Exception:
+ self.cleanup(
+ pod=self.pod or self.pod_request_obj,
+ remote_pod=remote_pod,
+ )
+ raise
+ self.cleanup(
+ pod=self.pod or self.pod_request_obj,
+ remote_pod=remote_pod,
+ )
+ if self.do_xcom_push:
+ return result
+ def execute_complete(self, context: Context, event: dict, **kwargs):
+ self.log.debug("Triggered with event: %s", event)
+ pod = None
+ try:
+ pod = self.hook.get_pod(
+ event["name"],
+ event["namespace"],
+ )
+ if self.callbacks:
+ self.callbacks.on_operator_resuming(
+ pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC
+ )
if event["status"] in ("error", "failed", "timeout"):
+ # fetch some logs when pod is failed
+ if self.get_logs:
+ self.write_logs(pod)
+ if "stack_trace" in event:
+ message = f"{event['message']}\n{event['stack_trace']}"
+ else:
+ message = event["message"]
if self.do_xcom_push:
- _ = self.extract_xcom(pod=self.pod)
-
- message = event.get("stack_trace", event["message"])
+ # In the event of base container failure, we need to kill the xcom sidecar.
+ # We disregard xcom output and do that here
+ _ = self.extract_xcom(pod=pod)
raise AirflowException(message)
-
- elif event["status"] == "running":
+ elif event["status"] == "success":
+ # fetch some logs when pod is executed successfully
if self.get_logs:
- last_log_time = event.get("last_log_time")
- self.log.info("Resuming logs read from time %r", last_log_time)
-
- pod_log_status = self.pod_manager.fetch_container_logs(
- pod=self.pod,
- container_name=self.BASE_CONTAINER_NAME,
- follow=self.logging_interval is None,
- since_time=last_log_time,
- )
+ self.write_logs(pod)
- if pod_log_status.running:
- self.log.info("Container still running; deferring again.")
- self.invoke_defer_method(pod_log_status.last_log_time)
- else:
- self.invoke_defer_method()
-
- elif event["status"] == "success":
if self.do_xcom_push:
- xcom_sidecar_output = self.extract_xcom(pod=self.pod)
+ xcom_sidecar_output = self.extract_xcom(pod=pod)
return xcom_sidecar_output
- return
- except TaskDeferred:
- raise
finally:
- self._clean(event)
-
- def _clean(self, event: dict[str, Any]):
- if event["status"] == "running":
- return
- if self.get_logs:
- self.write_logs(self.pod)
- istio_enabled = self.is_istio_enabled(self.pod)
- # Skip await_pod_completion when the event is 'timeout' due to the pod can hang
- # on the ErrImagePull or ContainerCreating step and it will never complete
- if event["status"] != "timeout":
- self.pod = self.pod_manager.await_pod_completion(
- self.pod, istio_enabled, self.base_container_name
- )
- if self.pod is not None:
- self.post_complete_action(
- pod=self.pod,
- remote_pod=self.pod,
- )
-
- @deprecated(reason="use `trigger_reentry` instead.", category=AirflowProviderDeprecationWarning)
- def execute_complete(self, context: Context, event: dict, **kwargs):
- self.trigger_reentry(context=context, event=event)
+ istio_enabled = self.is_istio_enabled(pod)
+ # Skip await_pod_completion when the event is 'timeout' due to the pod can hang
+ # on the ErrImagePull or ContainerCreating step and it will never complete
+ if event["status"] != "timeout":
+ pod = self.pod_manager.await_pod_completion(pod, istio_enabled, self.base_container_name)
+ if pod is not None:
+ self.post_complete_action(
+ pod=pod,
+ remote_pod=pod,
+ )
def write_logs(self, pod: k8s.V1Pod):
try:
diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py
index c9b1e62226..e34a73f146 100644
--- a/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -30,8 +30,10 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import (
OnFinishAction,
PodLaunchTimeoutException,
PodPhase,
+ container_is_running,
)
from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils import timezone
if TYPE_CHECKING:
from kubernetes_asyncio.client.models import V1Pod
@@ -158,49 +160,22 @@ class KubernetesPodTrigger(BaseTrigger):
self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
try:
state = await self._wait_for_pod_start()
- if state == ContainerState.TERMINATED:
+ if state in PodPhase.terminal_states:
event = TriggerEvent(
- {
- "status": "success",
- "namespace": self.pod_namespace,
- "name": self.pod_name,
- "message": "All containers inside pod have started successfully.",
- }
- )
- elif state == ContainerState.FAILED:
- event = TriggerEvent(
- {
- "status": "failed",
- "namespace": self.pod_namespace,
- "name": self.pod_name,
- "message": "pod failed",
- }
+ {"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)
else:
event = await self._wait_for_container_completion()
yield event
- return
- except PodLaunchTimeoutException as e:
- message = self._format_exception_description(e)
- yield TriggerEvent(
- {
- "name": self.pod_name,
- "namespace": self.pod_namespace,
- "status": "timeout",
- "message": message,
- }
- )
except Exception as e:
+ description = self._format_exception_description(e)
yield TriggerEvent(
{
- "name": self.pod_name,
- "namespace": self.pod_namespace,
"status": "error",
- "message": str(e),
- "stack_trace": traceback.format_exc(),
+ "error_type": e.__class__.__name__,
+ "description": description,
}
)
- return
def _format_exception_description(self, exc: Exception) -> Any:
if isinstance(exc, PodLaunchTimeoutException):
@@ -214,13 +189,14 @@ class KubernetesPodTrigger(BaseTrigger):
description += f"\ntrigger traceback:\n{curr_traceback}"
return description
- async def _wait_for_pod_start(self) -> ContainerState:
+ async def _wait_for_pod_start(self) -> Any:
"""Loops until pod phase leaves ``PENDING`` If timeout is reached, throws error."""
- delta = datetime.datetime.now(tz=datetime.timezone.utc) - self.trigger_start_time
- while self.startup_timeout >= delta.total_seconds():
+ start_time = timezone.utcnow()
+ timeout_end = start_time + datetime.timedelta(seconds=self.startup_timeout)
+ while timeout_end > timezone.utcnow():
pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
if not pod.status.phase == "Pending":
- return self.define_container_state(pod)
+ return pod.status.phase
self.log.info("Still waiting for pod to start. The pod state is %s", pod.status.phase)
await asyncio.sleep(self.poll_interval)
raise PodLaunchTimeoutException("Pod did not leave 'Pending' phase within specified timeout")
@@ -232,30 +208,18 @@ class KubernetesPodTrigger(BaseTrigger):
Waits until container is no longer in running state. If trigger is configured with a logging period,
then will emit an event to resume the task for the purpose of fetching more logs.
"""
- time_begin = datetime.datetime.now(tz=datetime.timezone.utc)
+ time_begin = timezone.utcnow()
time_get_more_logs = None
if self.logging_interval is not None:
time_get_more_logs = time_begin + datetime.timedelta(seconds=self.logging_interval)
while True:
pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
- container_state = self.define_container_state(pod)
- if container_state == ContainerState.TERMINATED:
- return TriggerEvent(
- {"status": "success", "namespace": self.pod_namespace, "name": self.pod_name}
- )
- elif container_state == ContainerState.FAILED:
- return TriggerEvent(
- {"status": "failed", "namespace": self.pod_namespace, "name": self.pod_name}
- )
- if time_get_more_logs and datetime.datetime.now(tz=datetime.timezone.utc) > time_get_more_logs:
+ if not container_is_running(pod=pod, container_name=self.base_container_name):
return TriggerEvent(
- {
- "status": "running",
- "last_log_time": self.last_log_time,
- "namespace": self.pod_namespace,
- "name": self.pod_name,
- }
+ {"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)
+ if time_get_more_logs and timezone.utcnow() > time_get_more_logs:
+ return TriggerEvent({"status": "running", "last_log_time": self.last_log_time})
await asyncio.sleep(self.poll_interval)
def _get_async_hook(self) -> AsyncKubernetesHook:
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py
index faa21eb7d7..c27cd23146 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -35,6 +35,7 @@ from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperato
from airflow.providers.cncf.kubernetes.secret import Secret
from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
+ PodLaunchTimeoutException,
PodLoggingStatus,
PodPhase,
)
@@ -1972,39 +1973,41 @@ class TestKubernetesPodOperatorAsync:
with pytest.raises(AirflowException, match=expect_match):
k.cleanup(pod, pod)
- @mock.patch(f"{HOOK_CLASS}.get_pod")
+ @mock.patch(
+ "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.raise_for_trigger_status"
+ )
+ @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
def test_get_logs_running(
self,
fetch_container_logs,
await_pod_completion,
- get_pod,
+ find_pod,
+ raise_for_trigger_status,
):
"""When logs fetch exits with status running, raise task deferred"""
pod = MagicMock()
- get_pod.return_value = pod
+ find_pod.return_value = pod
op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True)
await_pod_completion.return_value = None
fetch_container_logs.return_value = PodLoggingStatus(True, None)
with pytest.raises(TaskDeferred):
- op.trigger_reentry(
- create_context(op),
- event={"name": TEST_NAME, "namespace": TEST_NAMESPACE, "status": "running"},
- )
+ op.trigger_reentry(create_context(op), None)
fetch_container_logs.is_called_with(pod, "base")
@mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
+ @mock.patch(
+ "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.raise_for_trigger_status"
+ )
@mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
- def test_get_logs_not_running(self, fetch_container_logs, find_pod, cleanup):
+ def test_get_logs_not_running(self, fetch_container_logs, find_pod, raise_for_trigger_status, cleanup):
pod = MagicMock()
find_pod.return_value = pod
op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True)
fetch_container_logs.return_value = PodLoggingStatus(False, None)
- op.trigger_reentry(
- create_context(op), event={"name": TEST_NAME, "namespace": TEST_NAMESPACE, "status": "success"}
- )
+ op.trigger_reentry(create_context(op), None)
fetch_container_logs.is_called_with(pod, "base")
@mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
@@ -2013,15 +2016,14 @@ class TestKubernetesPodOperatorAsync:
"""Assert that trigger_reentry raise exception in case of error"""
find_pod.return_value = MagicMock()
op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True)
- with pytest.raises(AirflowException):
+ with pytest.raises(PodLaunchTimeoutException):
context = create_context(op)
op.trigger_reentry(
context,
{
- "status": "timeout",
- "message": "any message",
- "name": TEST_NAME,
- "namespace": TEST_NAMESPACE,
+ "status": "error",
+ "error_type": "PodLaunchTimeoutException",
+ "description": "any message",
},
)
diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py
index bed52811fc..d12100e4e3 100644
--- a/tests/providers/cncf/kubernetes/triggers/test_pod.py
+++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py
@@ -122,10 +122,9 @@ class TestKubernetesPodTrigger:
expected_event = TriggerEvent(
{
- "status": "success",
- "namespace": "default",
- "name": "test-pod-name",
- "message": "All containers inside pod have started successfully.",
+ "pod_name": POD_NAME,
+ "namespace": NAMESPACE,
+ "status": "done",
}
)
actual_event = await trigger.run().asend(None)
@@ -133,11 +132,16 @@ class TestKubernetesPodTrigger:
assert actual_event == expected_event
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_PATH}.define_container_state")
+ @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running")
+ @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod")
+ @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
@mock.patch(f"{TRIGGER_PATH}.hook")
- async def test_run_loop_return_waiting_event(self, mock_hook, mock_method, trigger, caplog):
+ async def test_run_loop_return_waiting_event(
+ self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog
+ ):
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.WAITING
+ mock_container_is_running.return_value = True
caplog.set_level(logging.INFO)
@@ -149,11 +153,16 @@ class TestKubernetesPodTrigger:
assert f"Sleeping for {POLL_INTERVAL} seconds."
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_PATH}.define_container_state")
+ @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running")
+ @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod")
+ @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
@mock.patch(f"{TRIGGER_PATH}.hook")
- async def test_run_loop_return_running_event(self, mock_hook, mock_method, trigger, caplog):
+ async def test_run_loop_return_running_event(
+ self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog
+ ):
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.RUNNING
+ mock_container_is_running.return_value = True
caplog.set_level(logging.INFO)
@@ -178,7 +187,11 @@ class TestKubernetesPodTrigger:
mock_method.return_value = ContainerState.FAILED
expected_event = TriggerEvent(
- {"status": "failed", "namespace": "default", "name": "test-pod-name", "message": "pod failed"}
+ {
+ "pod_name": POD_NAME,
+ "namespace": NAMESPACE,
+ "status": "done",
+ }
)
actual_event = await trigger.run().asend(None)
@@ -197,14 +210,8 @@ class TestKubernetesPodTrigger:
generator = trigger.run()
actual = await generator.asend(None)
- actual_stack_trace = actual.payload.pop("stack_trace")
- assert (
- TriggerEvent(
- {"name": POD_NAME, "namespace": NAMESPACE, "status": "error", "message": "Test exception"}
- )
- == actual
- )
- assert actual_stack_trace.startswith("Traceback (most recent call last):")
+ actual_stack_trace = actual.payload.pop("description")
+ assert actual_stack_trace.startswith("Trigger KubernetesPodTrigger failed with exception Exception")
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@@ -228,24 +235,16 @@ class TestKubernetesPodTrigger:
@pytest.mark.parametrize(
"logging_interval, exp_event",
[
- param(
- 0,
- {
- "status": "running",
- "last_log_time": DateTime(2022, 1, 1),
- "name": POD_NAME,
- "namespace": NAMESPACE,
- },
- id="short_interval",
- ),
+ param(0, {"status": "running", "last_log_time": DateTime(2022, 1, 1)}, id="short_interval"),
+ param(None, {"status": "done", "namespace": mock.ANY, "pod_name": mock.ANY}, id="no_interval"),
],
)
- @mock.patch(f"{TRIGGER_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
- @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.AsyncKubernetesHook.get_pod")
- async def test_running_log_interval(
- self, mock_get_pod, mock_wait_for_pod_start, define_container_state, logging_interval, exp_event
- ):
+ @mock.patch(
+ "kubernetes_asyncio.client.CoreV1Api.read_namespaced_pod",
+ new=get_read_pod_mock_containers([1, 1, None, None]),
+ )
+ @mock.patch("kubernetes_asyncio.config.load_kube_config")
+ async def test_running_log_interval(self, load_kube_config, logging_interval, exp_event):
"""
If log interval given, should emit event with running status and last log time.
Otherwise, should make it to second loop and emit "done" event.
@@ -255,15 +254,14 @@ class TestKubernetesPodTrigger:
interval is None, the second "running" status will just result in continuation of the loop. And
when in the next loop we get a non-running status, the trigger fires a "done" event.
"""
- define_container_state.return_value = "running"
trigger = KubernetesPodTrigger(
- pod_name=POD_NAME,
- pod_namespace=NAMESPACE,
- trigger_start_time=datetime.datetime.now(tz=datetime.timezone.utc),
- base_container_name=BASE_CONTAINER_NAME,
+ pod_name=mock.ANY,
+ pod_namespace=mock.ANY,
+ trigger_start_time=mock.ANY,
+ base_container_name=mock.ANY,
startup_timeout=5,
poll_interval=1,
- logging_interval=1,
+ logging_interval=logging_interval,
last_log_time=DateTime(2022, 1, 1),
)
assert await trigger.run().__anext__() == TriggerEvent(exp_event)
@@ -308,12 +306,12 @@ class TestKubernetesPodTrigger:
@pytest.mark.asyncio
@pytest.mark.parametrize("container_state", [ContainerState.WAITING, ContainerState.UNDEFINED])
- @mock.patch(f"{TRIGGER_PATH}.define_container_state")
+ @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_timeout_event(
self, mock_hook, mock_method, trigger, caplog, container_state
):
- trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(minutes=2)
+ trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(seconds=5)
mock_hook.get_pod.return_value = self._mock_pod_result(
mock.MagicMock(
status=mock.MagicMock(
@@ -327,14 +325,4 @@ class TestKubernetesPodTrigger:
generator = trigger.run()
actual = await generator.asend(None)
- assert (
- TriggerEvent(
- {
- "name": POD_NAME,
- "namespace": NAMESPACE,
- "status": "timeout",
- "message": "Pod did not leave 'Pending' phase within specified timeout",
- }
- )
- == actual
- )
+ assert actual == TriggerEvent({"status": "done", "namespace": NAMESPACE, "pod_name": POD_NAME})
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index c6a2d4e72f..ca7b7ba358 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -108,20 +108,19 @@ class TestGKEStartPodTrigger:
}
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
+ @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_success_event_should_execute_successfully(
- self, mock_hook, mock_wait_pod, trigger
+ self, mock_hook, mock_method, trigger
):
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
- mock_wait_pod.return_value = ContainerState.TERMINATED
+ mock_method.return_value = ContainerState.TERMINATED
expected_event = TriggerEvent(
{
- "name": POD_NAME,
+ "pod_name": POD_NAME,
"namespace": NAMESPACE,
- "status": "success",
- "message": "All containers inside pod have started successfully.",
+ "status": "done",
}
)
actual_event = await trigger.run().asend(None)
@@ -129,10 +128,10 @@ class TestGKEStartPodTrigger:
assert actual_event == expected_event
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
+ @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_failed_event_should_execute_successfully(
- self, mock_hook, mock_wait_pod, trigger
+ self, mock_hook, mock_method, trigger
):
mock_hook.get_pod.return_value = self._mock_pod_result(
mock.MagicMock(
@@ -141,14 +140,13 @@ class TestGKEStartPodTrigger:
)
)
)
- mock_wait_pod.return_value = ContainerState.FAILED
+ mock_method.return_value = ContainerState.FAILED
expected_event = TriggerEvent(
{
- "name": POD_NAME,
+ "pod_name": POD_NAME,
"namespace": NAMESPACE,
- "status": "failed",
- "message": "pod failed",
+ "status": "done",
}
)
actual_event = await trigger.run().asend(None)
@@ -156,15 +154,18 @@ class TestGKEStartPodTrigger:
assert actual_event == expected_event
@pytest.mark.asyncio
+ @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running")
+ @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod")
@mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
- @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_waiting_event_should_execute_successfully(
- self, mock_hook, mock_method, mock_wait_pod, trigger, caplog
+ self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog
):
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
- mock_method.return_value = ContainerState.WAITING
+ mock_method.return_value = ContainerState.RUNNING
+ mock_container_is_running.return_value = True
+ trigger.logging_interval = 10
caplog.set_level(logging.INFO)
task = asyncio.create_task(trigger.run().__anext__())
@@ -175,13 +176,15 @@ class TestGKEStartPodTrigger:
assert f"Sleeping for {POLL_INTERVAL} seconds."
@pytest.mark.asyncio
+ @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running")
+ @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod")
@mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
- @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_running_event_should_execute_successfully(
- self, mock_hook, mock_method, mock_wait_pod, trigger, caplog
+ self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog
):
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+ mock_container_is_running.return_value = True
mock_method.return_value = ContainerState.RUNNING
caplog.set_level(logging.INFO)
@@ -194,10 +197,9 @@ class TestGKEStartPodTrigger:
assert f"Sleeping for {POLL_INTERVAL} seconds."
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_logging_in_trigger_when_exception_should_execute_successfully(
- self, mock_hook, mock_wait_pod, trigger, caplog
+ self, mock_hook, trigger, caplog
):
"""
Test that GKEStartPodTrigger fires the correct event in case of an error.
@@ -206,14 +208,9 @@ class TestGKEStartPodTrigger:
generator = trigger.run()
actual = await generator.asend(None)
- actual_stack_trace = actual.payload.pop("stack_trace")
- assert (
- TriggerEvent(
- {"name": POD_NAME, "namespace": NAMESPACE, "status": "error", "message": "Test exception"}
- )
- == actual
- )
- assert actual_stack_trace.startswith("Traceback (most recent call last):")
+
+ actual_stack_trace = actual.payload.pop("description")
+ assert actual_stack_trace.startswith("Trigger GKEStartPodTrigger failed with exception Exception")
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")