You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2019/01/10 22:49:12 UTC

[GitHub] stale[bot] closed pull request #2538: [AIRFLOW-1491] Recover celery queue on restart

stale[bot] closed pull request #2538: [AIRFLOW-1491] Recover celery queue on restart
URL: https://github.com/apache/airflow/pull/2538
 
 
   

This is a PR submitted from a forked repository.
As GitHub hides the original diff once a PR is merged or closed, it is
displayed below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 7a4065eb07..eb8bfc8205 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -120,7 +120,7 @@ def heartbeat(self):
             self.queued_tasks.pop(key)
             ti.refresh_from_db()
             if ti.state != State.RUNNING:
-                self.running[key] = command
+                self.running[key] = True
                 self.execute_async(key, command=command, queue=queue)
             else:
                 self.logger.debug(
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 17c343bd4a..ef8d82c13c 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from builtins import object
+
 import logging
 import subprocess
 import ssl
@@ -20,10 +21,13 @@
 import traceback
 
 from celery import Celery
+from celery.result import AsyncResult
 from celery import states as celery_states
 
 from airflow.exceptions import AirflowConfigException, AirflowException
 from airflow.executors.base_executor import BaseExecutor
+from airflow.models import ExecutorQueueManager
+
 from airflow import configuration
 
 PARALLELISM = configuration.get('core', 'PARALLELISM')
@@ -49,10 +53,13 @@ class CeleryConfig(object):
     CELERY_DEFAULT_QUEUE = DEFAULT_QUEUE
     CELERY_DEFAULT_EXCHANGE = DEFAULT_QUEUE
 
+    #CELERY_SEND_EVENTS = True
+    #CELERY_EVENT_QUEUE_EXPIRES = 120
+
     celery_ssl_active = False
     try:
         celery_ssl_active = configuration.getboolean('celery', 'CELERY_SSL_ACTIVE')
-    except AirflowConfigException as e:
+    except AirflowConfigException:
         logging.warning("Celery Executor will run without SSL")
 
     try:
@@ -61,10 +68,10 @@ class CeleryConfig(object):
                               'certfile': configuration.get('celery', 'CELERY_SSL_CERT'),
                               'ca_certs': configuration.get('celery', 'CELERY_SSL_CACERT'),
                               'cert_reqs': ssl.CERT_REQUIRED}
-    except AirflowConfigException as e:
+    except AirflowConfigException:
         raise AirflowException('AirflowConfigException: CELERY_SSL_ACTIVE is True, please ensure CELERY_SSL_KEY, '
                                'CELERY_SSL_CERT and CELERY_SSL_CACERT are set')
-    except Exception as e:
+    except Exception:
         raise AirflowException('Exception: There was an unknown Celery SSL Error.  Please ensure you want to use '
                                'SSL and/or have all necessary certs and key.')
 
@@ -74,13 +81,13 @@ class CeleryConfig(object):
 
 
 @app.task
-def execute_command(command):
-    logging.info("Executing command in Celery " + command)
+def execute_command(command, key):
+    logging.info("[celery] executing command {} for {} ".format(command, key))
     try:
         subprocess.check_call(command, shell=True)
     except subprocess.CalledProcessError as e:
         logging.error(e)
-        raise AirflowException('Celery command failed')
+        raise AirflowException('Celery command failed for {}'.format(key))
 
 
 class CeleryExecutor(BaseExecutor):
@@ -92,49 +99,60 @@ class CeleryExecutor(BaseExecutor):
     vast amounts of messages, while providing operations with the tools
     required to maintain such a system.
     """
+    def __init__(self, parallelism=PARALLELISM):
+        super(CeleryExecutor, self).__init__(parallelism=parallelism)
+        self.tasks = ExecutorQueueManager()
 
     def start(self):
-        self.tasks = {}
-        self.last_state = {}
+        self._recover_queue()
 
     def execute_async(self, key, command, queue=DEFAULT_QUEUE):
-        self.logger.info( "[celery] queuing {key} through celery, "
-                       "queue={queue}".format(**locals()))
+        self.logger.info("[celery] queuing {key} through celery, "
+                         "queue={queue}".format(**locals()))
         self.tasks[key] = execute_command.apply_async(
-            args=[command], queue=queue)
-        self.last_state[key] = celery_states.PENDING
+            args=[command, key], queue=queue)
 
     def sync(self):
-
         self.logger.debug(
             "Inquiring about {} celery task(s)".format(len(self.tasks)))
-        for key, async in list(self.tasks.items()):
+        for key, uuid in list(self.tasks.items()):
+            async = AsyncResult(id=uuid, app=app)
             try:
                 state = async.state
-                if self.last_state[key] != state:
-                    if state == celery_states.SUCCESS:
-                        self.success(key)
-                        del self.tasks[key]
-                        del self.last_state[key]
-                    elif state == celery_states.FAILURE:
-                        self.fail(key)
-                        del self.tasks[key]
-                        del self.last_state[key]
-                    elif state == celery_states.REVOKED:
-                        self.fail(key)
-                        del self.tasks[key]
-                        del self.last_state[key]
-                    else:
-                        self.logger.info("Unexpected state: " + async.state)
-                    self.last_state[key] = async.state
+                if state == celery_states.SUCCESS:
+                    self.success(key)
+                    del self.tasks[key]
+                elif state == celery_states.FAILURE:
+                    self.fail(key)
+                    del self.tasks[key]
+                elif state == celery_states.REVOKED:
+                    self.fail(key)
+                    del self.tasks[key]
+                else:
+                    self.logger.warning("Unexpected state: " + async.state)
             except Exception as e:
                 logging.error("Error syncing the celery executor, ignoring "
                               "it:\n{}\n".format(e, traceback.format_exc()))
 
     def end(self, synchronous=False):
         if synchronous:
+            tasks = []
+            for uuid in self.tasks.values():
+                tasks.append(AsyncResult(id=uuid, app=app))
+
             while any([
-                    async.state not in celery_states.READY_STATES
-                    for async in self.tasks.values()]):
+                    not task.ready()
+                    for task in tasks]):
                 time.sleep(5)
         self.sync()
+
+    def _recover_queue(self):
+        """
+        In case of a scheduler (re)start figure out if there are already tasks
+        in the queue which have been sent to the workers.
+        :return: None
+        """
+        self.logger.info("Recovering task queue")
+        for key, uuid in self.tasks.items():
+            self.logger.info("Recovering command for {}".format(key))
+            self.running[key] = True
diff --git a/airflow/migrations/versions/a7006f20d0e0_add_executor_queue.py b/airflow/migrations/versions/a7006f20d0e0_add_executor_queue.py
new file mode 100644
index 0000000000..05f2e04003
--- /dev/null
+++ b/airflow/migrations/versions/a7006f20d0e0_add_executor_queue.py
@@ -0,0 +1,42 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""add executor queue
+
+Revision ID: a7006f20d0e0
+Revises: 947454bf1dff
+Create Date: 2017-08-22 10:54:07.425417
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = 'a7006f20d0e0'
+down_revision = '947454bf1dff'
+branch_labels = None
+depends_on = None
+
+from alembic import op
+import sqlalchemy as sa
+
+
+def upgrade():
+    op.create_table('executor_queue',
+                    sa.Column('key', sa.String(length=250), nullable=False),
+                    sa.Column('id', sa.String(length=250), nullable=False),
+                    sa.Column('updated', sa.DateTime(), nullable=True),
+                    sa.PrimaryKeyConstraint('key')
+                    )
+
+
+def downgrade():
+    op.drop_table('executor_queue')
diff --git a/airflow/models.py b/airflow/models.py
index d83bc9a73d..a4ce009c9c 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -2630,6 +2630,65 @@ def xcom_pull(
             include_prior_dates=include_prior_dates)
 
 
+class ExecutorQueue(Base):
+    __tablename__ = "executor_queue"
+
+    key = Column(String(ID_LEN), primary_key=True)
+    id = Column(String(ID_LEN), nullable=False)
+    updated = Column(DateTime, nullable=False)
+
+
+class ExecutorQueueManager(dict):
+    def __init__(self, *args, **kwargs):
+        super(ExecutorQueueManager, self).__init__(*args, **kwargs)
+        self.session = settings.Session()
+
+    def __getitem__(self, key):
+        qry = (self.session.query(ExecutorQueue)
+               .filter(ExecutorQueue.key == key)
+               .value(ExecutorQueue.id))
+
+        return qry
+
+    def __setitem__(self, key, value):
+        """
+        :param key: task_id
+        :param value: AsyncResult
+        :return:
+        """
+        task = ExecutorQueue()
+        task.key = key
+        task.id = value.id
+        task.updated = datetime.utcnow()
+        self.session.merge(task)
+        self.session.commit()
+
+    def __contains__(self, item):
+        if self.__getitem__(item):
+            return True
+
+        return False
+
+    def __iter__(self):
+        return self.session.query(ExecutorQueue).values(ExecutorQueue.key,
+                                                        ExecutorQueue.id)
+
+    def __delitem__(self, key):
+        self.session.query(ExecutorQueue).filter(ExecutorQueue.key == key).delete()
+        self.session.commit()
+
+    def iteritems(self):
+        return self.__iter__()
+
+    def items(self):
+        return list(self.__iter__())
+
+    def values(self):
+        qry = self.session.query(ExecutorQueue).values(ExecutorQueue.id)
+        l = [val[0] for val in qry]
+        return l
+
+
 class DagModel(Base):
 
     __tablename__ = "dag"
diff --git a/scripts/ci/airflow_travis.cfg b/scripts/ci/airflow_travis.cfg
index 01bf3a47cd..c0688aa45d 100644
--- a/scripts/ci/airflow_travis.cfg
+++ b/scripts/ci/airflow_travis.cfg
@@ -44,8 +44,8 @@ smtp_mail_from = airflow@example.com
 celery_app_name = airflow.executors.celery_executor
 celeryd_concurrency = 16
 worker_log_server_port = 8793
-broker_url = sqla+mysql://airflow:airflow@localhost:3306/airflow
-celery_result_backend = db+mysql://airflow:airflow@localhost:3306/airflow
+broker_url = sqla+mysql://root@localhost:3306/airflow
+celery_result_backend = db+mysql://root@localhost:3306/airflow
 flower_port = 5555
 default_queue = default
 
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
new file mode 100644
index 0000000000..ad99376d1f
--- /dev/null
+++ b/tests/executors/test_celery_executor.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+from airflow.executors.celery_executor import app
+from airflow.executors.celery_executor import CeleryExecutor
+from airflow.utils.state import State
+from celery.contrib.testing.worker import start_worker
+
+# Leave this import in place: it is used by the test worker.
+import celery.contrib.testing.tasks
+
+
+class CeleryExecutorTest(unittest.TestCase):
+    def test_celery_recovery(self):
+        with start_worker(app=app):
+            executor = CeleryExecutor()
+            executor.start()
+
+            delay_command = 'echo test && sleep 2'
+
+            executor.execute_async(key='delay', command=delay_command)
+            executor.running['delay'] = True
+            executor.sync()
+
+            del executor.running['delay']
+            executor.start()
+            self.assertTrue(executor.running['delay'])
+
+            executor.end(synchronous=True)
+            self.assertTrue(executor.event_buffer['delay'], State.SUCCESS)
+
+            # needs to be combined
+            executor.start()
+
+            success_command = 'echo 1'
+            fail_command = 'exit 1'
+
+            executor.execute_async(key='success', command=success_command)
+            # errors are propagated for some reason
+            try:
+                executor.execute_async(key='fail', command=fail_command)
+            except:
+                pass
+            executor.running['success'] = True
+            executor.running['fail'] = True
+
+            executor.end(synchronous=True)
+
+            self.assertTrue(executor.event_buffer['success'], State.SUCCESS)
+            self.assertTrue(executor.event_buffer['fail'], State.FAILED)
+
+            self.assertIsNone(executor.tasks['success'])
+            self.assertIsNone(executor.tasks['fail'])


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services