You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by as...@apache.org on 2021/04/15 12:07:28 UTC

[airflow] 09/36: Fix celery executor bug trying to call len on map (#14883)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 22b2a800ba81e2a90ef40b7a92eb80d4eb67acb2
Author: Ryan Hatter <25...@users.noreply.github.com>
AuthorDate: Tue Apr 6 05:21:38 2021 -0400

    Fix celery executor bug trying to call len on map (#14883)
    
    Co-authored-by: RNHTTR <ry...@wiftapp.com>
    (cherry picked from commit 4ee442970873ba59ee1d1de3ac78ef8e33666e0f)
---
 airflow/executors/celery_executor.py    | 22 ++++++++++-----------
 tests/executors/test_celery_executor.py | 35 +++++++++++++++++++++++----------
 2 files changed, 35 insertions(+), 22 deletions(-)

diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index a670294..2d0e915 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -476,7 +476,7 @@ class CeleryExecutor(BaseExecutor):
             return tis
 
         states_by_celery_task_id = self.bulk_state_fetcher.get_many(
-            map(operator.itemgetter(0), celery_tasks.values())
+            list(map(operator.itemgetter(0), celery_tasks.values()))
         )
 
         adopted = []
@@ -526,10 +526,6 @@ def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str,
         return async_result.task_id, ExceptionWithTraceback(e, exception_traceback), None
 
 
-def _tasks_list_to_task_ids(async_tasks) -> Set[str]:
-    return {a.task_id for a in async_tasks}
-
-
 class BulkStateFetcher(LoggingMixin):
     """
     Gets status for many Celery tasks using the best method available
@@ -543,20 +539,22 @@ class BulkStateFetcher(LoggingMixin):
         super().__init__()
         self._sync_parallelism = sync_parralelism
 
+    def _tasks_list_to_task_ids(self, async_tasks) -> Set[str]:
+        return {a.task_id for a in async_tasks}
+
     def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
         """Gets status for many Celery tasks using the best method available."""
         if isinstance(app.backend, BaseKeyValueStoreBackend):
             result = self._get_many_from_kv_backend(async_results)
-            return result
-        if isinstance(app.backend, DatabaseBackend):
+        elif isinstance(app.backend, DatabaseBackend):
             result = self._get_many_from_db_backend(async_results)
-            return result
-        result = self._get_many_using_multiprocessing(async_results)
-        self.log.debug("Fetched %d states for %d task", len(result), len(async_results))
+        else:
+            result = self._get_many_using_multiprocessing(async_results)
+        self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results))
         return result
 
     def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
-        task_ids = _tasks_list_to_task_ids(async_tasks)
+        task_ids = self._tasks_list_to_task_ids(async_tasks)
         keys = [app.backend.get_key_for_task(k) for k in task_ids]
         values = app.backend.mget(keys)
         task_results = [app.backend.decode_result(v) for v in values if v]
@@ -565,7 +563,7 @@ class BulkStateFetcher(LoggingMixin):
         return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
 
     def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
-        task_ids = _tasks_list_to_task_ids(async_tasks)
+        task_ids = self._tasks_list_to_task_ids(async_tasks)
         session = app.backend.ResultSession()
         task_cls = getattr(app.backend, "task_cls", TaskDb)
         with session_cleanup(session):
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index 944fa49..4f93007 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -414,7 +414,9 @@ class TestBulkStateFetcher(unittest.TestCase):
     def test_should_support_kv_backend(self, mock_mget):
         with _prepare_app():
             mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
-            with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+            with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
+                "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
+            ) as cm:
                 fetcher = BulkStateFetcher()
                 result = fetcher.get_many(
                     [
@@ -429,6 +431,9 @@ class TestBulkStateFetcher(unittest.TestCase):
         mock_mget.assert_called_once_with(mock.ANY)
 
         assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
+        assert [
+            'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
+        ] == cm.output
 
     @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
     @pytest.mark.integration("redis")
@@ -438,21 +443,26 @@ class TestBulkStateFetcher(unittest.TestCase):
         with _prepare_app():
             mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
 
-            with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+            with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
+                "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
+            ) as cm:
                 mock_session = mock_backend.ResultSession.return_value  # pylint: disable=no-member
                 mock_session.query.return_value.filter.return_value.all.return_value = [
                     mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
                 ]
 
-        fetcher = BulkStateFetcher()
-        result = fetcher.get_many(
-            [
-                mock.MagicMock(task_id="123"),
-                mock.MagicMock(task_id="456"),
-            ]
-        )
+                fetcher = BulkStateFetcher()
+                result = fetcher.get_many(
+                    [
+                        mock.MagicMock(task_id="123"),
+                        mock.MagicMock(task_id="456"),
+                    ]
+                )
 
         assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
+        assert [
+            'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
+        ] == cm.output
 
     @pytest.mark.integration("redis")
     @pytest.mark.integration("rabbitmq")
@@ -461,7 +471,9 @@ class TestBulkStateFetcher(unittest.TestCase):
         with _prepare_app():
             mock_backend = mock.MagicMock(autospec=BaseBackend)
 
-            with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+            with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
+                "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
+            ) as cm:
                 fetcher = BulkStateFetcher(1)
                 result = fetcher.get_many(
                     [
@@ -471,3 +483,6 @@ class TestBulkStateFetcher(unittest.TestCase):
                 )
 
         assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
+        assert [
+            'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
+        ] == cm.output