You are viewing a plain text version of this content; the canonical link was not preserved in this rendering.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/04/15 12:07:28 UTC
[airflow] 09/36: Fix celery executor bug trying to call len on map
(#14883)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 22b2a800ba81e2a90ef40b7a92eb80d4eb67acb2
Author: Ryan Hatter <25...@users.noreply.github.com>
AuthorDate: Tue Apr 6 05:21:38 2021 -0400
Fix celery executor bug trying to call len on map (#14883)
Co-authored-by: RNHTTR <ry...@wiftapp.com>
(cherry picked from commit 4ee442970873ba59ee1d1de3ac78ef8e33666e0f)
---
airflow/executors/celery_executor.py | 22 ++++++++++-----------
tests/executors/test_celery_executor.py | 35 +++++++++++++++++++++++----------
2 files changed, 35 insertions(+), 22 deletions(-)
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index a670294..2d0e915 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -476,7 +476,7 @@ class CeleryExecutor(BaseExecutor):
return tis
states_by_celery_task_id = self.bulk_state_fetcher.get_many(
- map(operator.itemgetter(0), celery_tasks.values())
+ list(map(operator.itemgetter(0), celery_tasks.values()))
)
adopted = []
@@ -526,10 +526,6 @@ def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str,
return async_result.task_id, ExceptionWithTraceback(e, exception_traceback), None
-def _tasks_list_to_task_ids(async_tasks) -> Set[str]:
- return {a.task_id for a in async_tasks}
-
-
class BulkStateFetcher(LoggingMixin):
"""
Gets status for many Celery tasks using the best method available
@@ -543,20 +539,22 @@ class BulkStateFetcher(LoggingMixin):
super().__init__()
self._sync_parallelism = sync_parralelism
+ def _tasks_list_to_task_ids(self, async_tasks) -> Set[str]:
+ return {a.task_id for a in async_tasks}
+
def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
"""Gets status for many Celery tasks using the best method available."""
if isinstance(app.backend, BaseKeyValueStoreBackend):
result = self._get_many_from_kv_backend(async_results)
- return result
- if isinstance(app.backend, DatabaseBackend):
+ elif isinstance(app.backend, DatabaseBackend):
result = self._get_many_from_db_backend(async_results)
- return result
- result = self._get_many_using_multiprocessing(async_results)
- self.log.debug("Fetched %d states for %d task", len(result), len(async_results))
+ else:
+ result = self._get_many_using_multiprocessing(async_results)
+ self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results))
return result
def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
- task_ids = _tasks_list_to_task_ids(async_tasks)
+ task_ids = self._tasks_list_to_task_ids(async_tasks)
keys = [app.backend.get_key_for_task(k) for k in task_ids]
values = app.backend.mget(keys)
task_results = [app.backend.decode_result(v) for v in values if v]
@@ -565,7 +563,7 @@ class BulkStateFetcher(LoggingMixin):
return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
- task_ids = _tasks_list_to_task_ids(async_tasks)
+ task_ids = self._tasks_list_to_task_ids(async_tasks)
session = app.backend.ResultSession()
task_cls = getattr(app.backend, "task_cls", TaskDb)
with session_cleanup(session):
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index 944fa49..4f93007 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -414,7 +414,9 @@ class TestBulkStateFetcher(unittest.TestCase):
def test_should_support_kv_backend(self, mock_mget):
with _prepare_app():
mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
- with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+ with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
+ "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
+ ) as cm:
fetcher = BulkStateFetcher()
result = fetcher.get_many(
[
@@ -429,6 +431,9 @@ class TestBulkStateFetcher(unittest.TestCase):
mock_mget.assert_called_once_with(mock.ANY)
assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
+ assert [
+ 'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
+ ] == cm.output
@mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
@pytest.mark.integration("redis")
@@ -438,21 +443,26 @@ class TestBulkStateFetcher(unittest.TestCase):
with _prepare_app():
mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
- with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+ with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
+ "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
+ ) as cm:
mock_session = mock_backend.ResultSession.return_value # pylint: disable=no-member
mock_session.query.return_value.filter.return_value.all.return_value = [
mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
]
- fetcher = BulkStateFetcher()
- result = fetcher.get_many(
- [
- mock.MagicMock(task_id="123"),
- mock.MagicMock(task_id="456"),
- ]
- )
+ fetcher = BulkStateFetcher()
+ result = fetcher.get_many(
+ [
+ mock.MagicMock(task_id="123"),
+ mock.MagicMock(task_id="456"),
+ ]
+ )
assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
+ assert [
+ 'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
+ ] == cm.output
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@@ -461,7 +471,9 @@ class TestBulkStateFetcher(unittest.TestCase):
with _prepare_app():
mock_backend = mock.MagicMock(autospec=BaseBackend)
- with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+ with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
+ "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
+ ) as cm:
fetcher = BulkStateFetcher(1)
result = fetcher.get_many(
[
@@ -471,3 +483,6 @@ class TestBulkStateFetcher(unittest.TestCase):
)
assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
+ assert [
+ 'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
+ ] == cm.output