You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by sa...@apache.org on 2017/09/06 20:49:37 UTC

incubator-airflow git commit: [AIRFLOW-1493][AIRFLOW-XXXX][WIP] fixed dumb thing

Repository: incubator-airflow
Updated Branches:
  refs/heads/master af91e2ac0 -> b2e1753f5


[AIRFLOW-1493][AIRFLOW-XXXX][WIP] fixed dumb thing

Closes #2505 from saguziel/aguziel-fix-double-trigger


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/b2e1753f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/b2e1753f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/b2e1753f

Branch: refs/heads/master
Commit: b2e1753f5b74ad1b6e0889f7b784ce69623c95ce
Parents: af91e2a
Author: Alex Guziel <al...@airbnb.com>
Authored: Wed Sep 6 13:49:34 2017 -0700
Committer: Alex Guziel <al...@airbnb.com>
Committed: Wed Sep 6 13:49:34 2017 -0700

----------------------------------------------------------------------
 airflow/bin/cli.py                      |  6 +-
 airflow/jobs.py                         | 47 +++++++-------
 airflow/models.py                       | 95 +++++++++++++++++++++++++---
 airflow/task_runner/base_task_runner.py |  3 -
 tests/dags/test_mark_success.py         | 30 +++++++++
 tests/jobs.py                           | 58 +++++++++++++++--
 tests/models.py                         | 18 +++++-
 7 files changed, 208 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b2e1753f/airflow/bin/cli.py
----------------------------------------------------------------------
diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index e9e60cb..a0545c3 100755
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -381,12 +381,8 @@ def run(args, dag=None):
             pool=args.pool)
         run_job.run()
     elif args.raw:
-        ti.run(
+        ti._run_raw_task(
             mark_success=args.mark_success,
-            ignore_all_deps=args.ignore_all_dependencies,
-            ignore_depends_on_past=args.ignore_depends_on_past,
-            ignore_task_deps=args.ignore_dependencies,
-            ignore_ti_state=args.force,
             job_id=args.job_id,
             pool=args.pool,
         )

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b2e1753f/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index d94a0e0..138a055 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -2448,10 +2448,6 @@ class LocalTaskJob(BaseJob):
         # terminate multiple times
         self.terminating = False
 
-        # Keeps track of the fact that the task instance has been observed
-        # as running at least once
-        self.was_running = False
-
         super(LocalTaskJob, self).__init__(*args, **kwargs)
 
     def _execute(self):
@@ -2464,6 +2460,17 @@ class LocalTaskJob(BaseJob):
             raise AirflowException("LocalTaskJob received SIGTERM signal")
         signal.signal(signal.SIGTERM, signal_handler)
 
+        if not self.task_instance._check_and_change_state_before_execution(
+                mark_success=self.mark_success,
+                ignore_all_deps=self.ignore_all_deps,
+                ignore_depends_on_past=self.ignore_depends_on_past,
+                ignore_task_deps=self.ignore_task_deps,
+                ignore_ti_state=self.ignore_ti_state,
+                job_id=self.id,
+                pool=self.pool):
+            self.logger.info("Task is not able to be run") 
+            return 
+
         try:
             self.task_runner.start()
 
@@ -2506,44 +2513,34 @@ class LocalTaskJob(BaseJob):
         self.task_runner.terminate()
         self.task_runner.on_finish()
 
-    def _is_descendant_process(self, pid):
-        """Checks if pid is a descendant of the current process.
-
-        :param pid: process id to check
-        :type pid: int
-        :rtype: bool
-        """
-        try:
-            return psutil.Process(pid) in psutil.Process().children(recursive=True)
-        except psutil.NoSuchProcess:
-            return False
-
     @provide_session
     def heartbeat_callback(self, session=None):
         """Self destruct task if state has been moved away from running externally"""
 
         if self.terminating:
-            # task is already terminating, let it breathe
+            # ensure termination if processes are created later
+            self.task_runner.terminate()
             return
 
         self.task_instance.refresh_from_db()
         ti = self.task_instance
+
+        fqdn = socket.getfqdn()
+        same_hostname = fqdn == ti.hostname
+        same_process = ti.pid == os.getpid()
+
         if ti.state == State.RUNNING:
-            self.was_running = True
-            fqdn = socket.getfqdn()
-            if fqdn != ti.hostname:
+            if not same_hostname:
                 logging.warning("The recorded hostname {ti.hostname} "
                                 "does not match this instance's hostname "
                                 "{fqdn}".format(**locals()))
                 raise AirflowException("Hostname of job runner does not match")
-            elif not self._is_descendant_process(ti.pid):
+            elif not same_process:
                 current_pid = os.getpid()
-                logging.warning("Recorded pid {ti.pid} is not a "
-                                "descendant of the current pid "
+                logging.warning("Recorded pid {ti.pid} does not match the current pid "
                                 "{current_pid}".format(**locals()))
                 raise AirflowException("PID of job runner does not match")
-        elif (self.was_running
-              and self.task_runner.return_code() is None
+        elif (self.task_runner.return_code() is None
               and hasattr(self.task_runner, 'process')):
             logging.warning(
                 "State of this instance has been externally set to "

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b2e1753f/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index d83bc9a..3078f4e 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -1275,7 +1275,7 @@ class TaskInstance(Base):
         return dr
 
     @provide_session
-    def run(
+    def _check_and_change_state_before_execution(
             self,
             verbose=True,
             ignore_all_deps=False,
@@ -1288,7 +1288,9 @@ class TaskInstance(Base):
             pool=None,
             session=None):
         """
-        Runs the task instance.
+        Checks dependencies and then sets state to RUNNING if they are met. Returns
+        True if and only if state is set to RUNNING, which implies that task should be
+        executed, in preparation for _run_raw_task
 
         :param verbose: whether to turn on more verbose logging
         :type verbose: boolean
@@ -1306,6 +1308,8 @@ class TaskInstance(Base):
         :type test_mode: boolean
         :param pool: specifies the pool to use to run the task instance
         :type pool: str
+        :return: whether the state was changed to running or not
+        :rtype: bool
         """
         task = self.task
         self.pool = pool or task.pool
@@ -1329,7 +1333,7 @@ class TaskInstance(Base):
                 session=session,
                 verbose=True):
             session.commit()
-            return
+            return False
 
         hr = "\n" + ("-" * 80) + "\n"  # Line break
 
@@ -1368,7 +1372,7 @@ class TaskInstance(Base):
             logging.info(msg)
             session.merge(self)
             session.commit()
-            return
+            return False
 
         # Another worker might have started running this task instance while
         # the current worker process was blocked on refresh_from_db
@@ -1376,7 +1380,7 @@ class TaskInstance(Base):
             msg = "Task Instance already running {}".format(self)
             logging.warning(msg)
             session.commit()
-            return
+            return False
 
         # print status message
         logging.info(hr + msg + hr)
@@ -1396,14 +1400,44 @@ class TaskInstance(Base):
         settings.engine.dispose()
         if verbose:
             if mark_success:
-                msg = "Marking success for "
+                msg = "Marking success for {} on {}".format(self.task, self.execution_date)
+                logging.info(msg)
             else:
-                msg = "Executing "
-            msg += "{self.task} on {self.execution_date}"
+                msg = "Executing {} on {}".format(self.task, self.execution_date)
+                logging.info(msg)
+        return True
+
+    @provide_session
+    def _run_raw_task(
+            self,
+            mark_success=False,
+            test_mode=False,
+            job_id=None,
+            pool=None,
+            session=None):
+        """
+        Immediately runs the task (without checking or changing db state
+        before execution) and then sets the appropriate final state after
+        completion and runs any post-execute callbacks. Meant to be called
+        only after another function changes the state to running.
+
+        :param mark_success: Don't run the task, mark its state as success
+        :type mark_success: boolean
+        :param test_mode: Doesn't record success or failure in the DB
+        :type test_mode: boolean
+        :param pool: specifies the pool to use to run the task instance
+        :type pool: str
+        """
+        task = self.task
+        self.pool = pool or task.pool
+        self.test_mode = test_mode
+        self.refresh_from_db(session=session)
+        self.job_id = job_id
+        self.hostname = socket.getfqdn()
+        self.operator = task.__class__.__name__
 
         context = {}
         try:
-            logging.info(msg.format(self=self))
             if not mark_success:
                 context = self.get_template_context()
 
@@ -1460,9 +1494,20 @@ class TaskInstance(Base):
                 Stats.incr('operator_successes_{}'.format(
                     self.task.__class__.__name__), 1, 1)
                 Stats.incr('ti_successes')
+            self.refresh_from_db(lock_for_update=True)
             self.state = State.SUCCESS
         except AirflowSkipException:
+            self.refresh_from_db(lock_for_update=True)
             self.state = State.SKIPPED
+        except AirflowException as e:
+            self.refresh_from_db()
+            # for case when task is marked as success externally
+            # current behavior doesn't hit the success callback
+            if self.state == State.SUCCESS:
+                return
+            else:
+                self.handle_failure(e, test_mode, context)
+                raise
         except (Exception, KeyboardInterrupt) as e:
             self.handle_failure(e, test_mode, context)
             raise
@@ -1485,6 +1530,38 @@ class TaskInstance(Base):
 
         session.commit()
 
+    @provide_session
+    def run(
+            self,
+            verbose=True,
+            ignore_all_deps=False,
+            ignore_depends_on_past=False,
+            ignore_task_deps=False,
+            ignore_ti_state=False,
+            mark_success=False,
+            test_mode=False,
+            job_id=None,
+            pool=None,
+            session=None):
+        res = self._check_and_change_state_before_execution(
+                verbose=verbose,
+                ignore_all_deps=ignore_all_deps,
+                ignore_depends_on_past=ignore_depends_on_past,
+                ignore_task_deps=ignore_task_deps,
+                ignore_ti_state=ignore_ti_state,
+                mark_success=mark_success,
+                test_mode=test_mode,
+                job_id=job_id,
+                pool=pool,
+                session=session)
+        if res:
+            self._run_raw_task(
+                    mark_success=mark_success,
+                    test_mode=test_mode,
+                    job_id=job_id,
+                    pool=pool,
+                    session=session)
+
     def dry_run(self):
         task = self.task
         task_copy = copy.copy(task)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b2e1753f/airflow/task_runner/base_task_runner.py
----------------------------------------------------------------------
diff --git a/airflow/task_runner/base_task_runner.py b/airflow/task_runner/base_task_runner.py
index bed8eaa..8ca8f1a 100644
--- a/airflow/task_runner/base_task_runner.py
+++ b/airflow/task_runner/base_task_runner.py
@@ -79,9 +79,6 @@ class BaseTaskRunner(LoggingMixin):
         self._cfg_path = cfg_path
         self._command = popen_prepend + self._task_instance.command_as_list(
             raw=True,
-            ignore_all_deps=local_task_job.ignore_all_deps,
-            ignore_depends_on_past=local_task_job.ignore_depends_on_past,
-            ignore_ti_state=local_task_job.ignore_ti_state,
             pickle_id=local_task_job.pickle_id,
             mark_success=local_task_job.mark_success,
             job_id=local_task_job.id,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b2e1753f/tests/dags/test_mark_success.py
----------------------------------------------------------------------
diff --git a/tests/dags/test_mark_success.py b/tests/dags/test_mark_success.py
new file mode 100644
index 0000000..56a5662
--- /dev/null
+++ b/tests/dags/test_mark_success.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import datetime
+
+from airflow.models import DAG
+from airflow.operators.bash_operator import BashOperator
+
+DEFAULT_DATE = datetime(2016, 1, 1)
+
+args = {
+    'owner': 'airflow',
+    'start_date': DEFAULT_DATE,
+}
+
+dag = DAG(dag_id='test_mark_success', default_args=args)
+task = BashOperator(
+    task_id='task1',
+    bash_command='sleep 600',
+    dag=dag)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b2e1753f/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index 1c0b5cc..3039e38 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -19,12 +19,14 @@ from __future__ import unicode_literals
 
 import datetime
 import logging
+import multiprocessing
 import os
 import shutil
-import unittest
 import six
 import socket
 import threading
+import time
+import unittest
 from tempfile import mkdtemp
 
 from airflow import AirflowException, settings, models
@@ -34,6 +36,7 @@ from airflow.jobs import BackfillJob, SchedulerJob, LocalTaskJob
 from airflow.models import DAG, DagModel, DagBag, DagRun, Pool, TaskInstance as TI
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.bash_operator import BashOperator
+from airflow.task_runner.base_task_runner import BaseTaskRunner
 from airflow.utils.db import provide_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
@@ -729,8 +732,8 @@ class LocalTaskJobTest(unittest.TestCase):
     def setUp(self):
         pass
 
-    @patch.object(LocalTaskJob, "_is_descendant_process")
-    def test_localtaskjob_heartbeat(self, is_descendant):
+    @patch('os.getpid')
+    def test_localtaskjob_heartbeat(self, mock_pid):
         session = settings.Session()
         dag = DAG(
             'test_localtaskjob_heartbeat',
@@ -756,7 +759,7 @@ class LocalTaskJobTest(unittest.TestCase):
                             executor=SequentialExecutor())
         self.assertRaises(AirflowException, job1.heartbeat_callback)
 
-        is_descendant.return_value = True
+        mock_pid.return_value = 1
         ti.state = State.RUNNING
         ti.hostname = socket.getfqdn()
         ti.pid = 1
@@ -766,9 +769,50 @@ class LocalTaskJobTest(unittest.TestCase):
         ret = job1.heartbeat_callback()
         self.assertEqual(ret, None)
 
-        is_descendant.return_value = False
+        mock_pid.return_value = 2
         self.assertRaises(AirflowException, job1.heartbeat_callback)
 
+    def test_mark_success_no_kill(self):
+        """
+        Test that ensures that mark_success in the UI doesn't cause
+        the task to fail, and that the task exits
+        """
+        dagbag = models.DagBag(
+            dag_folder=TEST_DAG_FOLDER,
+            include_examples=False,
+        )
+        dag = dagbag.dags.get('test_mark_success')
+        task = dag.get_task('task1')
+
+        session = settings.Session()
+
+        dag.clear()
+        dr = dag.create_dagrun(run_id="test",
+                               state=State.RUNNING,
+                               execution_date=DEFAULT_DATE,
+                               start_date=DEFAULT_DATE,
+                               session=session)
+        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
+        process = multiprocessing.Process(target=job1.run)
+        process.start()
+        ti.refresh_from_db()
+        for i in range(0, 50):
+            if ti.state == State.RUNNING:
+                break
+            time.sleep(0.1)
+            ti.refresh_from_db()
+        self.assertEqual(State.RUNNING, ti.state)
+        ti.state = State.SUCCESS
+        session.merge(ti)
+        session.commit()
+
+        process.join(timeout=5)
+        self.assertFalse(process.is_alive())
+        ti.refresh_from_db()
+        self.assertEqual(State.SUCCESS, ti.state)
+
     def test_localtaskjob_double_trigger(self):
         dagbag = models.DagBag(
             dag_folder=TEST_DAG_FOLDER,
@@ -795,7 +839,9 @@ class LocalTaskJobTest(unittest.TestCase):
         job1 = LocalTaskJob(task_instance=ti_run,
                             ignore_ti_state=True,
                             executor=SequentialExecutor())
-        self.assertRaises(AirflowException, job1.run)
+        with patch.object(BaseTaskRunner, 'start', return_value=None) as mock_method:
+            job1.run()
+            mock_method.assert_not_called()
 
         ti = dr.get_task_instance(task_id=task.task_id, session=session)
         self.assertEqual(ti.pid, 1)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b2e1753f/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 96275d3..db5beca 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -865,7 +865,7 @@ class TaskInstanceTest(unittest.TestCase):
         ti = TI(
             task=task, execution_date=datetime.datetime.now())
         ti.run()
-        self.assertTrue(ti.state == models.State.SKIPPED)
+        self.assertEqual(models.State.SKIPPED, ti.state)
 
     def test_retry_delay(self):
         """
@@ -1186,6 +1186,22 @@ class TaskInstanceTest(unittest.TestCase):
         with self.assertRaises(TestError):
             ti.run()
 
+    def test_check_and_change_state_before_execution(self):
+        dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
+        task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
+        ti = TI(
+            task=task, execution_date=datetime.datetime.now())
+        self.assertTrue(ti._check_and_change_state_before_execution())
+
+    def test_check_and_change_state_before_execution_dep_not_met(self):
+        dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
+        task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
+        task2= DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
+        task >> task2
+        ti = TI(
+            task=task2, execution_date=datetime.datetime.now())
+        self.assertFalse(ti._check_and_change_state_before_execution())
+        
 
 class ClearTasksTest(unittest.TestCase):
     def test_clear_task_instances(self):