You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2023/03/09 07:01:52 UTC

[airflow] branch main updated: try_number was not being passed to the get_task_log method, instead (#28817)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f5ed4d599c try_number was not being passed to the get_task_log method, instead (#28817)
f5ed4d599c is described below

commit f5ed4d599cadd90ca097e304234253e8aa2adad9
Author: sanjayp <sa...@gmail.com>
AuthorDate: Thu Mar 9 01:01:43 2023 -0600

    try_number was not being passed to the get_task_log method, instead (#28817)
    
    ti.try_number was used for fetching the log from the k8s pod.
    This caused the wrong log to be returned for the k8s pod.
    Fixed by passing try_number from _read to the get_task_log method, and
    by using the try_number argument instead of ti.try_number when selecting
    the pod in the k8s executor.
    
    Co-authored-by: eladkal <45...@users.noreply.github.com>
---
 airflow/executors/base_executor.py                 | 6 +++---
 airflow/executors/celery_kubernetes_executor.py    | 4 ++--
 airflow/executors/kubernetes_executor.py           | 4 ++--
 airflow/executors/local_kubernetes_executor.py     | 4 ++--
 airflow/utils/log/file_task_handler.py             | 4 ++--
 tests/executors/test_base_executor.py              | 2 +-
 tests/executors/test_celery_kubernetes_executor.py | 6 +++---
 tests/executors/test_kubernetes_executor.py        | 4 ++--
 tests/executors/test_local_kubernetes_executor.py  | 6 +++---
 tests/utils/test_log_handlers.py                   | 8 ++++----
 10 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 2b6f6e2fe4..db8cdd2d56 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -357,13 +357,13 @@ class BaseExecutor(LoggingMixin):
         """
         raise NotImplementedError()
 
-    def get_task_log(self, ti: TaskInstance) -> tuple[list[str], list[str]]:
+    def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
         """
         This method can be implemented by any child class to return the task logs.
 
         :param ti: A TaskInstance object
-        :param log: log str
-        :return: logs or tuple of logs and meta dict
+        :param try_number: current try_number to read log from
+        :return: tuple of logs and messages
         """
         return [], []
 
diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py
index ea04b8a42b..94e9df684e 100644
--- a/airflow/executors/celery_kubernetes_executor.py
+++ b/airflow/executors/celery_kubernetes_executor.py
@@ -143,10 +143,10 @@ class CeleryKubernetesExecutor(LoggingMixin):
             cfg_path=cfg_path,
         )
 
-    def get_task_log(self, ti: TaskInstance) -> tuple[list[str], list[str]]:
+    def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
         """Fetch task log from Kubernetes executor"""
         if ti.queue == self.kubernetes_executor.kubernetes_queue:
-            return self.kubernetes_executor.get_task_log(ti=ti)
+            return self.kubernetes_executor.get_task_log(ti=ti, try_number=try_number)
         return [], []
 
     def has_task(self, task_instance: TaskInstance) -> bool:
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 962e64cd48..276afba321 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -781,7 +781,7 @@ class KubernetesExecutor(BaseExecutor):
             namespace = pod_override.metadata.namespace
         return namespace or conf.get("kubernetes_executor", "namespace", fallback="default")
 
-    def get_task_log(self, ti: TaskInstance) -> tuple[list[str], list[str]]:
+    def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
         messages = []
         log = []
         try:
@@ -794,7 +794,7 @@ class KubernetesExecutor(BaseExecutor):
             selector = PodGenerator.build_selector_for_k8s_executor_pod(
                 dag_id=ti.dag_id,
                 task_id=ti.task_id,
-                try_number=ti.try_number,
+                try_number=try_number,
                 map_index=ti.map_index,
                 run_id=ti.run_id,
                 airflow_worker=ti.queued_by_job_id,
diff --git a/airflow/executors/local_kubernetes_executor.py b/airflow/executors/local_kubernetes_executor.py
index 059aa11dcb..9ce34dce8b 100644
--- a/airflow/executors/local_kubernetes_executor.py
+++ b/airflow/executors/local_kubernetes_executor.py
@@ -144,10 +144,10 @@ class LocalKubernetesExecutor(LoggingMixin):
             cfg_path=cfg_path,
         )
 
-    def get_task_log(self, ti: TaskInstance) -> tuple[list[str], list[str]]:
+    def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
         """Fetch task log from kubernetes executor"""
         if ti.queue == self.kubernetes_executor.kubernetes_queue:
-            return self.kubernetes_executor.get_task_log(ti=ti)
+            return self.kubernetes_executor.get_task_log(ti=ti, try_number=try_number)
         return [], []
 
     def has_task(self, task_instance: TaskInstance) -> bool:
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index e1c9d5b504..6b597aa780 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -266,7 +266,7 @@ class FileTaskHandler(logging.Handler):
         return False
 
     @cached_property
-    def _executor_get_task_log(self) -> Callable[[TaskInstance], tuple[list[str], list[str]]]:
+    def _executor_get_task_log(self) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]:
         """This cached property avoids loading executor repeatedly."""
         executor = ExecutorLoader.get_default_executor()
         return executor.get_task_log
@@ -312,7 +312,7 @@ class FileTaskHandler(logging.Handler):
             remote_messages, remote_logs = self._read_remote_logs(ti, try_number, metadata)
             messages_list.extend(remote_messages)
         if ti.state == TaskInstanceState.RUNNING:
-            response = self._executor_get_task_log(ti)
+            response = self._executor_get_task_log(ti, try_number)
             if response:
                 executor_messages, executor_logs = response
             if executor_messages:
diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py
index b7d14078da..d57815fb36 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -55,7 +55,7 @@ def test_is_production_default_value():
 def test_get_task_log():
     executor = BaseExecutor()
     ti = TaskInstance(task=BaseOperator(task_id="dummy"))
-    assert executor.get_task_log(ti=ti) == ([], [])
+    assert executor.get_task_log(ti=ti, try_number=1) == ([], [])
 
 
 def test_serve_logs_default_value():
diff --git a/tests/executors/test_celery_kubernetes_executor.py b/tests/executors/test_celery_kubernetes_executor.py
index 8da4818003..ec222e71f7 100644
--- a/tests/executors/test_celery_kubernetes_executor.py
+++ b/tests/executors/test_celery_kubernetes_executor.py
@@ -185,13 +185,13 @@ class TestCeleryKubernetesExecutor:
         cke = CeleryKubernetesExecutor(celery_executor_mock, k8s_executor_mock)
         simple_task_instance = mock.MagicMock()
         simple_task_instance.queue = KUBERNETES_QUEUE
-        cke.get_task_log(ti=simple_task_instance)
-        k8s_executor_mock.get_task_log.assert_called_once_with(ti=simple_task_instance)
+        cke.get_task_log(ti=simple_task_instance, try_number=1)
+        k8s_executor_mock.get_task_log.assert_called_once_with(ti=simple_task_instance, try_number=1)
 
         k8s_executor_mock.reset_mock()
 
         simple_task_instance.queue = "test-queue"
-        log = cke.get_task_log(ti=simple_task_instance)
+        log = cke.get_task_log(ti=simple_task_instance, try_number=1)
         k8s_executor_mock.get_task_log.assert_not_called()
         assert log == ([], [])
 
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 954ba9c7d0..65c08b552c 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -1231,7 +1231,7 @@ class TestKubernetesExecutor:
         ti = create_task_instance_of_operator(EmptyOperator, dag_id="test_k8s_log_dag", task_id="test_task")
 
         executor = KubernetesExecutor()
-        messages, logs = executor.get_task_log(ti=ti)
+        messages, logs = executor.get_task_log(ti=ti, try_number=1)
 
         mock_kube_client.read_namespaced_pod_log.assert_called_once()
         assert "Trying to get logs (last 100 lines) from worker pod " in messages
@@ -1240,7 +1240,7 @@ class TestKubernetesExecutor:
         mock_kube_client.reset_mock()
         mock_kube_client.read_namespaced_pod_log.side_effect = Exception("error_fetching_pod_log")
 
-        messages, logs = executor.get_task_log(ti=ti)
+        messages, logs = executor.get_task_log(ti=ti, try_number=1)
         assert logs == [""]
         assert messages == [
             "Trying to get logs (last 100 lines) from worker pod ",
diff --git a/tests/executors/test_local_kubernetes_executor.py b/tests/executors/test_local_kubernetes_executor.py
index f92279cfea..e4d2fdd371 100644
--- a/tests/executors/test_local_kubernetes_executor.py
+++ b/tests/executors/test_local_kubernetes_executor.py
@@ -96,11 +96,11 @@ class TestLocalKubernetesExecutor:
         local_k8s_exec = LocalKubernetesExecutor(local_executor_mock, k8s_executor_mock)
         simple_task_instance = mock.MagicMock()
         simple_task_instance.queue = conf.get("local_kubernetes_executor", "kubernetes_queue")
-        local_k8s_exec.get_task_log(ti=simple_task_instance)
-        k8s_executor_mock.get_task_log.assert_called_once_with(ti=simple_task_instance)
+        local_k8s_exec.get_task_log(ti=simple_task_instance, try_number=3)
+        k8s_executor_mock.get_task_log.assert_called_once_with(ti=simple_task_instance, try_number=3)
         k8s_executor_mock.reset_mock()
         simple_task_instance.queue = "test-queue"
-        messages, logs = local_k8s_exec.get_task_log(ti=simple_task_instance)
+        messages, logs = local_k8s_exec.get_task_log(ti=simple_task_instance, try_number=3)
         k8s_executor_mock.get_task_log.assert_not_called()
         assert logs == []
         assert messages == []
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index 3f762b9a94..138301e161 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -288,9 +288,9 @@ class TestFileTaskLogHandler:
         ti.triggerer_job = None
         with conf_vars({("core", "executor"): executor_name}):
             fth = FileTaskHandler("")
-            fth._read(ti=ti, try_number=1)
+            fth._read(ti=ti, try_number=2)
         if state == TaskInstanceState.RUNNING:
-            mock_k8s_get_task_log.assert_called_once_with(ti)
+            mock_k8s_get_task_log.assert_called_once_with(ti, 2)
         else:
             mock_k8s_get_task_log.assert_not_called()
 
@@ -357,7 +357,7 @@ class TestFileTaskLogHandler:
         set_context(logger, ti)
         ti.run(ignore_ti_state=True)
         ti.state = TaskInstanceState.RUNNING
-        file_handler.read(ti, 3)
+        file_handler.read(ti, 2)
 
         # first we find pod name
         mock_list_pod.assert_called_once()
@@ -372,7 +372,7 @@ class TestFileTaskLogHandler:
                     "kubernetes_executor=True",
                     "run_id=manual__2016-01-01T0000000000-2b88d1d57",
                     "task_id=task_for_testing_file_log_handler",
-                    "try_number=.+?",
+                    "try_number=2",
                     "airflow-worker",
                 ]
             ),