You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bo...@apache.org on 2017/12/07 13:33:11 UTC

incubator-airflow git commit: [AIRFLOW-1873] Set TI.try_number to right value depending TI state

Repository: incubator-airflow
Updated Branches:
  refs/heads/master ad4f75111 -> 4b4e504ee


[AIRFLOW-1873] Set TI.try_number to right value depending TI state

Rather than having try_number+1 in various places,
try_number
will now automatically contain the right value for
when the TI
will next be run, and handle the case where
try_number is
accessed when the task is currently running.

This showed up as a bug where the logs from
running operators would
show up in the next log file (2.log for the first
try)

Closes #2832 from ashb/AIRFLOW-1873-task-operator-
log-try-number


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/4b4e504e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/4b4e504e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/4b4e504e

Branch: refs/heads/master
Commit: 4b4e504eeae81e48f3c9d796a61dd9e86000c663
Parents: ad4f751
Author: Ash Berlin-Taylor <as...@firemirror.com>
Authored: Thu Dec 7 13:31:38 2017 +0000
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Thu Dec 7 13:31:46 2017 +0000

----------------------------------------------------------------------
 airflow/models.py                      | 44 +++++++++++----
 airflow/utils/log/file_task_handler.py |  8 +--
 airflow/utils/log/gcs_task_handler.py  |  4 +-
 airflow/utils/log/s3_task_handler.py   |  4 +-
 tests/jobs.py                          |  3 +-
 tests/models.py                        | 87 +++++++++++++++++++----------
 tests/utils/test_log_handlers.py       | 53 +++++++++++++++---
 7 files changed, 144 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4b4e504e/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 5837363..16fae10 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -137,13 +137,13 @@ def clear_task_instances(tis, session, activate_dag_runs=True, dag=None):
             if dag and dag.has_task(task_id):
                 task = dag.get_task(task_id)
                 task_retries = task.retries
-                ti.max_tries = ti.try_number + task_retries
+                ti.max_tries = ti.try_number + task_retries - 1
             else:
                 # Ignore errors when updating max_tries if dag is None or
                 # task not found in dag since database records could be
                 # outdated. We make max_tries the maximum value of its
                 # original max_tries or the current task try number.
-                ti.max_tries = max(ti.max_tries, ti.try_number)
+                ti.max_tries = max(ti.max_tries, ti.try_number - 1)
             ti.state = State.NONE
             session.merge(ti)
 
@@ -771,7 +771,7 @@ class TaskInstance(Base, LoggingMixin):
     end_date = Column(UtcDateTime)
     duration = Column(Float)
     state = Column(String(20))
-    try_number = Column(Integer, default=0)
+    _try_number = Column('try_number', Integer, default=0)
     max_tries = Column(Integer)
     hostname = Column(String(1000))
     unixname = Column(String(1000))
@@ -813,6 +813,24 @@ class TaskInstance(Base, LoggingMixin):
         """ Initialize the attributes that aren't stored in the DB. """
         self.test_mode = False  # can be changed when calling 'run'
 
+    @property
+    def try_number(self):
+        """
+        Return the try number that this task number will be when it is actually
+        run.
+
+        If the TI is currently running, this will match the column in the
+        database, in all other cases this will be incremented.
+        """
+        # This is designed so that task logs end up in the right file.
+        if self.state == State.RUNNING:
+            return self._try_number
+        return self._try_number + 1
+
+    @try_number.setter
+    def try_number(self, value):
+        self._try_number = value
+
     def command(
             self,
             mark_success=False,
@@ -1041,7 +1059,9 @@ class TaskInstance(Base, LoggingMixin):
             self.state = ti.state
             self.start_date = ti.start_date
             self.end_date = ti.end_date
-            self.try_number = ti.try_number
+            # Get the raw value of try_number column, don't read through the
+            # accessor here otherwise it will be incremented by one already.
+            self.try_number = ti._try_number
             self.max_tries = ti.max_tries
             self.hostname = ti.hostname
             self.pid = ti.pid
@@ -1342,7 +1362,7 @@ class TaskInstance(Base, LoggingMixin):
         # not 0-indexed lists (i.e. Attempt 1 instead of
         # Attempt 0 for the first attempt).
         msg = "Starting attempt {attempt} of {total}".format(
-            attempt=self.try_number + 1,
+            attempt=self.try_number,
             total=self.max_tries + 1)
         self.start_date = timezone.utcnow()
 
@@ -1364,7 +1384,7 @@ class TaskInstance(Base, LoggingMixin):
             self.state = State.NONE
             msg = ("FIXME: Rescheduling due to concurrency limits reached at task "
                    "runtime. Attempt {attempt} of {total}. State set to NONE.").format(
-                attempt=self.try_number + 1,
+                attempt=self.try_number,
                 total=self.max_tries + 1)
             self.log.warning(hr + msg + hr)
 
@@ -1384,7 +1404,7 @@ class TaskInstance(Base, LoggingMixin):
 
         # print status message
         self.log.info(hr + msg + hr)
-        self.try_number += 1
+        self._try_number += 1
 
         if not test_mode:
             session.add(Log(State.RUNNING, self))
@@ -1586,10 +1606,10 @@ class TaskInstance(Base, LoggingMixin):
 
         # Let's go deeper
         try:
-            # try_number is incremented by 1 during task instance run. So the
-            # current task instance try_number is the try_number for the next
-            # task instance run. We only mark task instance as FAILED if the
-            # next task instance try_number exceeds the max_tries.
+            # Since this function is called only when the TI state is running,
+            # try_number contains the current try_number (not the next). We
+            # only mark task instance as FAILED if the next task instance
+            # try_number exceeds the max_tries.
             if task.retries and self.try_number <= self.max_tries:
                 self.state = State.UP_FOR_RETRY
                 self.log.info('Marking task as UP_FOR_RETRY')
@@ -1754,7 +1774,7 @@ class TaskInstance(Base, LoggingMixin):
             "Host: {self.hostname}<br>"
             "Log file: {self.log_filepath}<br>"
             "Mark success: <a href='{self.mark_success_url}'>Link</a><br>"
-        ).format(try_number=self.try_number + 1, max_tries=self.max_tries + 1, **locals())
+        ).format(try_number=self.try_number, max_tries=self.max_tries + 1, **locals())
         send_email(task.email, title, body)
 
     def set_duration(self):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4b4e504e/airflow/utils/log/file_task_handler.py
----------------------------------------------------------------------
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index f131c09..82fc349 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -89,7 +89,7 @@ class FileTaskHandler(logging.Handler):
         # Task instance here might be different from task instance when
         # initializing the handler. Thus explicitly getting log location
         # is needed to get correct log path.
-        log_relative_path = self._render_filename(ti, try_number + 1)
+        log_relative_path = self._render_filename(ti, try_number)
         location = os.path.join(self.local_base, log_relative_path)
 
         log = ""
@@ -144,8 +144,8 @@ class FileTaskHandler(logging.Handler):
         next_try = task_instance.try_number
 
         if try_number is None:
-            try_numbers = list(range(next_try))
-        elif try_number < 0:
+            try_numbers = list(range(1, next_try))
+        elif try_number < 1:
             logs = ['Error fetching the logs. Try number {} is invalid.'.format(try_number)]
             return logs
         else:
@@ -176,7 +176,7 @@ class FileTaskHandler(logging.Handler):
         # writable by both users, then it's possible that re-running a task
         # via the UI (or vice versa) results in a permission error as the task
         # tries to write to a log file created by the other user.
-        relative_path = self._render_filename(ti, ti.try_number + 1)
+        relative_path = self._render_filename(ti, ti.try_number)
         full_path = os.path.join(self.local_base, relative_path)
         directory = os.path.dirname(full_path)
         # Create the log file and give it group writable permissions

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4b4e504e/airflow/utils/log/gcs_task_handler.py
----------------------------------------------------------------------
diff --git a/airflow/utils/log/gcs_task_handler.py b/airflow/utils/log/gcs_task_handler.py
index b556cf0..fea5acc 100644
--- a/airflow/utils/log/gcs_task_handler.py
+++ b/airflow/utils/log/gcs_task_handler.py
@@ -58,7 +58,7 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
         # Log relative path is used to construct local and remote
         # log path to upload log files into GCS and read from the
         # remote location.
-        self.log_relative_path = self._render_filename(ti, ti.try_number + 1)
+        self.log_relative_path = self._render_filename(ti, ti.try_number)
 
     def close(self):
         """
@@ -94,7 +94,7 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
         # Explicitly getting log relative path is necessary as the given
         # task instance might be different than task instance passed in
         # in set_context method.
-        log_relative_path = self._render_filename(ti, try_number + 1)
+        log_relative_path = self._render_filename(ti, try_number)
         remote_loc = os.path.join(self.remote_base, log_relative_path)
 
         try:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4b4e504e/airflow/utils/log/s3_task_handler.py
----------------------------------------------------------------------
diff --git a/airflow/utils/log/s3_task_handler.py b/airflow/utils/log/s3_task_handler.py
index cfa966a..5ff90c6 100644
--- a/airflow/utils/log/s3_task_handler.py
+++ b/airflow/utils/log/s3_task_handler.py
@@ -53,7 +53,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
         super(S3TaskHandler, self).set_context(ti)
         # Local location and remote location is needed to open and
         # upload local log file to S3 remote storage.
-        self.log_relative_path = self._render_filename(ti, ti.try_number + 1)
+        self.log_relative_path = self._render_filename(ti, ti.try_number)
 
     def close(self):
         """
@@ -89,7 +89,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
         # Explicitly getting log relative path is necessary as the given
         # task instance might be different than task instance passed in
         # in set_context method.
-        log_relative_path = self._render_filename(ti, try_number + 1)
+        log_relative_path = self._render_filename(ti, try_number)
         remote_loc = os.path.join(self.remote_base, log_relative_path)
 
         if self.s3_log_exists(remote_loc):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4b4e504e/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index ca2db2c..77d872f 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -2404,10 +2404,11 @@ class SchedulerJobTest(unittest.TestCase):
         (command, priority, queue, ti) = ti_tuple
         ti.task = dag_task1
 
+        self.assertEqual(ti.try_number, 1)
         # fail execution
         run_with_error(ti)
         self.assertEqual(ti.state, State.UP_FOR_RETRY)
-        self.assertEqual(ti.try_number, 1)
+        self.assertEqual(ti.try_number, 2)
 
         ti.refresh_from_db(lock_for_update=True, session=session)
         ti.state = State.SCHEDULED

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4b4e504e/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index cabcf3a..a2433ab 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -925,10 +925,11 @@ class TaskInstanceTest(unittest.TestCase):
         ti = TI(
             task=task, execution_date=timezone.utcnow())
 
+        self.assertEqual(ti.try_number, 1)
         # first run -- up for retry
         run_with_error(ti)
         self.assertEqual(ti.state, State.UP_FOR_RETRY)
-        self.assertEqual(ti.try_number, 1)
+        self.assertEqual(ti.try_number, 2)
 
         # second run -- still up for retry because retry_delay hasn't expired
         run_with_error(ti)
@@ -965,16 +966,19 @@ class TaskInstanceTest(unittest.TestCase):
 
         ti = TI(
             task=task, execution_date=timezone.utcnow())
+        self.assertEqual(ti.try_number, 1)
 
         # first run -- up for retry
         run_with_error(ti)
         self.assertEqual(ti.state, State.UP_FOR_RETRY)
-        self.assertEqual(ti.try_number, 1)
+        self.assertEqual(ti._try_number, 1)
+        self.assertEqual(ti.try_number, 2)
 
         # second run -- fail
         run_with_error(ti)
         self.assertEqual(ti.state, State.FAILED)
-        self.assertEqual(ti.try_number, 2)
+        self.assertEqual(ti._try_number, 2)
+        self.assertEqual(ti.try_number, 3)
 
         # Clear the TI state since you can't run a task with a FAILED state without
         # clearing it first
@@ -983,12 +987,15 @@ class TaskInstanceTest(unittest.TestCase):
         # third run -- up for retry
         run_with_error(ti)
         self.assertEqual(ti.state, State.UP_FOR_RETRY)
-        self.assertEqual(ti.try_number, 3)
+        self.assertEqual(ti._try_number, 3)
+        self.assertEqual(ti.try_number, 4)
 
         # fourth run -- fail
         run_with_error(ti)
+        ti.refresh_from_db()
         self.assertEqual(ti.state, State.FAILED)
-        self.assertEqual(ti.try_number, 4)
+        self.assertEqual(ti._try_number, 4)
+        self.assertEqual(ti.try_number, 5)
 
     def test_next_retry_datetime(self):
         delay = datetime.timedelta(seconds=30)
@@ -1009,19 +1016,18 @@ class TaskInstanceTest(unittest.TestCase):
             task=task, execution_date=DEFAULT_DATE)
         ti.end_date = pendulum.instance(timezone.utcnow())
 
-        ti.try_number = 1
         dt = ti.next_retry_datetime()
         # between 30 * 2^0.5 and 30 * 2^1 (15 and 30)
         period = ti.end_date.add(seconds=30) - ti.end_date.add(seconds=15)
         self.assertTrue(dt in period)
 
-        ti.try_number = 4
+        ti.try_number = 3
         dt = ti.next_retry_datetime()
         # between 30 * 2^2 and 30 * 2^3 (120 and 240)
         period = ti.end_date.add(seconds=240) - ti.end_date.add(seconds=120)
         self.assertTrue(dt in period)
 
-        ti.try_number = 6
+        ti.try_number = 5
         dt = ti.next_retry_datetime()
         # between 30 * 2^4 and 30 * 2^5 (480 and 960)
         period = ti.end_date.add(seconds=960) - ti.end_date.add(seconds=480)
@@ -1229,7 +1235,11 @@ class TaskInstanceTest(unittest.TestCase):
         task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
         ti = TI(
             task=task, execution_date=timezone.utcnow())
+        self.assertEqual(ti._try_number, 0)
         self.assertTrue(ti._check_and_change_state_before_execution())
+        # State should be running, and try_number column should be incremented
+        self.assertEqual(ti.state, State.RUNNING)
+        self.assertEqual(ti._try_number, 1)
 
     def test_check_and_change_state_before_execution_dep_not_met(self):
         dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
@@ -1240,6 +1250,20 @@ class TaskInstanceTest(unittest.TestCase):
             task=task2, execution_date=timezone.utcnow())
         self.assertFalse(ti._check_and_change_state_before_execution())
 
+    def test_try_number(self):
+        """
+        Test the try_number accessor behaves in various running states
+        """
+        dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
+        task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
+        ti = TI(task=task, execution_date=timezone.utcnow())
+        self.assertEqual(1, ti.try_number)
+        ti.try_number = 2
+        ti.state = State.RUNNING
+        self.assertEqual(2, ti.try_number)
+        ti.state = State.SUCCESS
+        self.assertEqual(3, ti.try_number)
+
     def test_get_num_running_task_instances(self):
         session = settings.Session()
 
@@ -1282,9 +1306,10 @@ class ClearTasksTest(unittest.TestCase):
         session.commit()
         ti0.refresh_from_db()
         ti1.refresh_from_db()
-        self.assertEqual(ti0.try_number, 1)
+        # Next try to run will be try 2
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.max_tries, 1)
-        self.assertEqual(ti1.try_number, 1)
+        self.assertEqual(ti1.try_number, 2)
         self.assertEqual(ti1.max_tries, 3)
 
     def test_clear_task_instances_without_task(self):
@@ -1310,9 +1335,10 @@ class ClearTasksTest(unittest.TestCase):
         # When dag is None, max_tries will be maximum of original max_tries or try_number.
         ti0.refresh_from_db()
         ti1.refresh_from_db()
-        self.assertEqual(ti0.try_number, 1)
+        # Next try to run will be try 2
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.max_tries, 1)
-        self.assertEqual(ti1.try_number, 1)
+        self.assertEqual(ti1.try_number, 2)
         self.assertEqual(ti1.max_tries, 2)
 
     def test_clear_task_instances_without_dag(self):
@@ -1333,9 +1359,10 @@ class ClearTasksTest(unittest.TestCase):
         # When dag is None, max_tries will be maximum of original max_tries or try_number.
         ti0.refresh_from_db()
         ti1.refresh_from_db()
-        self.assertEqual(ti0.try_number, 1)
+        # Next try to run will be try 2
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.max_tries, 1)
-        self.assertEqual(ti1.try_number, 1)
+        self.assertEqual(ti1.try_number, 2)
         self.assertEqual(ti1.max_tries, 2)
 
     def test_dag_clear(self):
@@ -1343,12 +1370,13 @@ class ClearTasksTest(unittest.TestCase):
                   end_date=DEFAULT_DATE + datetime.timedelta(days=10))
         task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag)
         ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
-        self.assertEqual(ti0.try_number, 0)
-        ti0.run()
+        # Next try to run will be try 1
         self.assertEqual(ti0.try_number, 1)
+        ti0.run()
+        self.assertEqual(ti0.try_number, 2)
         dag.clear()
         ti0.refresh_from_db()
-        self.assertEqual(ti0.try_number, 1)
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.state, State.NONE)
         self.assertEqual(ti0.max_tries, 1)
 
@@ -1357,8 +1385,9 @@ class ClearTasksTest(unittest.TestCase):
         ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
         self.assertEqual(ti1.max_tries, 2)
         ti1.try_number = 1
+        # Next try will be 2
         ti1.run()
-        self.assertEqual(ti1.try_number, 2)
+        self.assertEqual(ti1.try_number, 3)
         self.assertEqual(ti1.max_tries, 2)
 
         dag.clear()
@@ -1366,9 +1395,9 @@ class ClearTasksTest(unittest.TestCase):
         ti1.refresh_from_db()
         # after clear dag, ti2 should show attempt 3 of 5
         self.assertEqual(ti1.max_tries, 4)
-        self.assertEqual(ti1.try_number, 2)
+        self.assertEqual(ti1.try_number, 3)
         # after clear dag, ti1 should show attempt 2 of 2
-        self.assertEqual(ti0.try_number, 1)
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.max_tries, 1)
 
     def test_dags_clear(self):
@@ -1388,7 +1417,7 @@ class ClearTasksTest(unittest.TestCase):
         for i in range(num_of_dags):
             tis[i].run()
             self.assertEqual(tis[i].state, State.SUCCESS)
-            self.assertEqual(tis[i].try_number, 1)
+            self.assertEqual(tis[i].try_number, 2)
             self.assertEqual(tis[i].max_tries, 0)
 
         DAG.clear_dags(dags)
@@ -1396,14 +1425,14 @@ class ClearTasksTest(unittest.TestCase):
         for i in range(num_of_dags):
             tis[i].refresh_from_db()
             self.assertEqual(tis[i].state, State.NONE)
-            self.assertEqual(tis[i].try_number, 1)
+            self.assertEqual(tis[i].try_number, 2)
             self.assertEqual(tis[i].max_tries, 1)
 
         # test dry_run
         for i in range(num_of_dags):
             tis[i].run()
             self.assertEqual(tis[i].state, State.SUCCESS)
-            self.assertEqual(tis[i].try_number, 2)
+            self.assertEqual(tis[i].try_number, 3)
             self.assertEqual(tis[i].max_tries, 1)
 
         DAG.clear_dags(dags, dry_run=True)
@@ -1411,7 +1440,7 @@ class ClearTasksTest(unittest.TestCase):
         for i in range(num_of_dags):
             tis[i].refresh_from_db()
             self.assertEqual(tis[i].state, State.SUCCESS)
-            self.assertEqual(tis[i].try_number, 2)
+            self.assertEqual(tis[i].try_number, 3)
             self.assertEqual(tis[i].max_tries, 1)
 
         # test only_failed
@@ -1427,11 +1456,11 @@ class ClearTasksTest(unittest.TestCase):
             tis[i].refresh_from_db()
             if i != failed_dag_idx:
                 self.assertEqual(tis[i].state, State.SUCCESS)
-                self.assertEqual(tis[i].try_number, 2)
+                self.assertEqual(tis[i].try_number, 3)
                 self.assertEqual(tis[i].max_tries, 1)
             else:
                 self.assertEqual(tis[i].state, State.NONE)
-                self.assertEqual(tis[i].try_number, 2)
+                self.assertEqual(tis[i].try_number, 3)
                 self.assertEqual(tis[i].max_tries, 2)
 
     def test_operator_clear(self):
@@ -1446,17 +1475,17 @@ class ClearTasksTest(unittest.TestCase):
         ti2 = TI(task=t2, execution_date=DEFAULT_DATE)
         ti2.run()
         # Dependency not met
-        self.assertEqual(ti2.try_number, 0)
+        self.assertEqual(ti2.try_number, 1)
         self.assertEqual(ti2.max_tries, 1)
 
         t2.clear(upstream=True)
         ti1.run()
         ti2.run()
-        self.assertEqual(ti1.try_number, 1)
+        self.assertEqual(ti1.try_number, 2)
         # max_tries is 0 because there is no task instance in db for ti1
         # so clear won't change the max_tries.
         self.assertEqual(ti1.max_tries, 0)
-        self.assertEqual(ti2.try_number, 1)
+        self.assertEqual(ti2.try_number, 2)
         # try_number (0) + retries(1)
         self.assertEqual(ti2.max_tries, 1)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4b4e504e/tests/utils/test_log_handlers.py
----------------------------------------------------------------------
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index fd5006c..0feb363 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -12,18 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import logging
 import logging.config
-import mock
 import os
 import unittest
+import six
 
-from airflow.models import TaskInstance, DAG
+from airflow.models import TaskInstance, DAG, DagRun
 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
 from airflow.operators.dummy_operator import DummyOperator
+from airflow.operators.python_operator import PythonOperator
 from airflow.utils.timezone import datetime
+from airflow.utils.log.logging_mixin import set_context
 from airflow.utils.log.file_task_handler import FileTaskHandler
+from airflow.utils.db import create_session
 
 DEFAULT_DATE = datetime(2016, 1, 1)
 TASK_LOGGER = 'airflow.task'
@@ -32,10 +34,21 @@ FILE_TASK_HANDLER = 'file.task'
 
 class TestFileTaskLogHandler(unittest.TestCase):
 
+    def cleanUp(self):
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TaskInstance).delete()
+
     def setUp(self):
         super(TestFileTaskLogHandler, self).setUp()
-        # We use file task handler by default.
         logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
+        logging.root.disabled = False
+        self.cleanUp()
+        # We use file task handler by default.
+
+    def tearDown(self):
+        self.cleanUp()
+        super(TestFileTaskLogHandler, self).tearDown()
 
     def test_default_task_logging_setup(self):
         # file task handler is used by default.
@@ -46,29 +59,51 @@ class TestFileTaskLogHandler(unittest.TestCase):
         self.assertEqual(handler.name, FILE_TASK_HANDLER)
 
     def test_file_task_handler(self):
+        def task_callable(ti, **kwargs):
+            ti.log.info("test")
         dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
-        task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=dag)
+        task = PythonOperator(
+            task_id='task_for_testing_file_log_handler',
+            dag=dag,
+            python_callable=task_callable,
+            provide_context=True
+        )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
 
-        logger = logging.getLogger(TASK_LOGGER)
+        logger = ti.log
+        ti.log.disabled = False
+
         file_handler = next((handler for handler in logger.handlers
                              if handler.name == FILE_TASK_HANDLER), None)
         self.assertIsNotNone(file_handler)
 
-        file_handler.set_context(ti)
+        set_context(logger, ti)
         self.assertIsNotNone(file_handler.handler)
         # We expect set_context generates a file locally.
         log_filename = file_handler.handler.baseFilename
         self.assertTrue(os.path.isfile(log_filename))
+        self.assertTrue(log_filename.endswith("1.log"), log_filename)
+
+        ti.run(ignore_ti_state=True)
 
-        logger.info("test")
-        ti.run()
+        file_handler.flush()
+        file_handler.close()
 
         self.assertTrue(hasattr(file_handler, 'read'))
         # Return value of read must be a list.
         logs = file_handler.read(ti)
         self.assertTrue(isinstance(logs, list))
         self.assertEqual(len(logs), 1)
+        target_re = r'\n\[[^\]]+\] {test_log_handlers.py:\d+} INFO - test\n'
+
+        # We should expect our log line from the callable above to appear in
+        # the logs we read back
+        six.assertRegex(
+            self,
+            logs[0],
+            target_re,
+            "Logs were " + str(logs)
+        )
 
         # Remove the generated tmp log file.
         os.remove(log_filename)