You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2020/05/06 11:14:20 UTC

[GitHub] [airflow] mik-laj commented on a change in pull request #7542: [AIRFLOW-6921] Fetch celery states in bulk

mik-laj commented on a change in pull request #7542:
URL: https://github.com/apache/airflow/pull/7542#discussion_r420712216



##########
File path: airflow/executors/celery_executor.py
##########
@@ -319,3 +262,109 @@ def execute_async(self,
 
     def terminate(self):
         pass  # no-op: nothing to clean up when the executor is terminated
+
+
+def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback]]:
+    """
+    Fetch and return the state of the given celery task. The scope of this function is
+    global so that it can be called by subprocesses in the pool.
+
+    :param async_result: the async Celery object used to fetch the task's state
+        (any Exception raised while fetching is returned instead of the state)
+    :type async_result: celery.result.AsyncResult
+    :return: a tuple of the task ID and either its Celery state or an ExceptionWithTraceback
+    :rtype: tuple[str, Union[str, ExceptionWithTraceback]]
+    """
+
+    try:
+        with timeout(seconds=OPERATION_TIMEOUT):
+            # Accessing the state property of a Celery task makes an actual network request
+            # to get the current state of the task
+            return async_result.task_id, async_result.state
+    except Exception as e:  # pylint: disable=broad-except
+        exception_traceback = "Celery Task ID: {}\n{}".format(async_result, traceback.format_exc())
+        return async_result.task_id, ExceptionWithTraceback(e, exception_traceback)
+
+
+def _tasks_list_to_task_ids(async_tasks) -> Set[str]:
+    return {a.task_id for a in async_tasks}  # set of task_id values of the given async results
+
+
+class BulkStateFetcher(LoggingMixin):
+    """
+    Gets status for many Celery tasks using the best method available
+
+    If BaseKeyValueStoreBackend is used as result backend, the mget method is used.
+    If DatabaseBackend is used as result backend, the SELECT ...WHER task_id IN (...) query is used
+    Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
+    """
+    def __init__(self, sync_parralelism=None):
+        super().__init__()
+        self._sync_parallelism = sync_parralelism
+
+    def get_many(self, async_results) -> Mapping[str, str]:
+        """
+        Gets status for many Celery tasks using the best method available.
+        """
+        if isinstance(app.backend, BaseKeyValueStoreBackend):
+            result = self._get_many_from_kv_backend(async_results)
+            return result
+        if isinstance(app.backend, DatabaseBackend):
+            result = self._get_many_from_db_backend(async_results)
+            return result
+        result = self._get_many_using_multiprocessing(async_results)
+        self.log.debug("Fetched %d states for %d task", len(result), len(async_results))
+        return result
+
+    def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, str]:

Review comment:
       I am afraid the description of these methods would be identical in all cases. A method's description should describe its behavior, and the behavior is identical everywhere; only the implementation details differ.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org