Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/09/21 05:00:30 UTC

[GitHub] feng-tao closed pull request #3596: [AIRFLOW-2747] Explicit re-schedule of sensors

feng-tao closed pull request #3596: [AIRFLOW-2747] Explicit re-schedule of sensors
URL: https://github.com/apache/incubator-airflow/pull/3596

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 89f3d0e048..d4098c4a32 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -47,6 +47,17 @@ class AirflowSensorTimeout(AirflowException):
     pass
 
 
+class AirflowRescheduleException(AirflowException):
+    """
+    Raise when the task should be re-scheduled at a later time.
+
+    :param reschedule_date: The date when the task should be rescheduled
+    :type reschedule_date: datetime
+    """
+    def __init__(self, reschedule_date):
+        self.reschedule_date = reschedule_date
+
+
 class AirflowTaskTimeout(AirflowException):
     pass
 
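For illustration (not part of this diff), a task can request an explicit
re-schedule by raising the new exception with an absolute datetime. A
minimal sketch; the helper name and the five-minute interval are
hypothetical:

    from datetime import timedelta

    from airflow.exceptions import AirflowRescheduleException
    from airflow.utils import timezone

    def poke_or_reschedule(condition_met):
        # Hypothetical helper: if the condition is not met yet, ask the
        # scheduler to run this task again ~5 minutes from now instead
        # of sleeping in a worker slot.
        if not condition_met:
            raise AirflowRescheduleException(
                reschedule_date=timezone.utcnow() + timedelta(minutes=5))
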
diff --git a/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py
new file mode 100644
index 0000000000..6eef6a9437
--- /dev/null
+++ b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py
@@ -0,0 +1,83 @@
+# flake8: noqa
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""add task_reschedule table
+
+Revision ID: 0a2a5b66e19d
+Revises: 9635ae0956e7
+Create Date: 2018-06-17 22:50:00.053620
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '0a2a5b66e19d'
+down_revision = '9635ae0956e7'
+branch_labels = None
+depends_on = None
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+
+
+TABLE_NAME = 'task_reschedule'
+INDEX_NAME = 'idx_' + TABLE_NAME + '_dag_task_date'
+
+def mysql_timestamp():
+    return mysql.TIMESTAMP(fsp=6)
+
+def sa_timestamp():
+    return sa.TIMESTAMP(timezone=True)
+
+def upgrade():
+    # See 0e2a74e0fc9f_add_time_zone_awareness
+    conn = op.get_bind()
+    if conn.dialect.name == 'mysql':
+        timestamp = mysql_timestamp
+    else:
+        timestamp = sa_timestamp
+
+    op.create_table(
+        TABLE_NAME,
+        sa.Column('id', sa.Integer(), nullable=False),
+        sa.Column('task_id', sa.String(length=250), nullable=False),
+        sa.Column('dag_id', sa.String(length=250), nullable=False),
+        # use explicit server_default=None, otherwise MySQL implies a default for the first timestamp column
+        sa.Column('execution_date', timestamp(), nullable=False, server_default=None),
+        sa.Column('try_number', sa.Integer(), nullable=False),
+        sa.Column('start_date', timestamp(), nullable=False),
+        sa.Column('end_date', timestamp(), nullable=False),
+        sa.Column('duration', sa.Integer(), nullable=False),
+        sa.Column('reschedule_date', timestamp(), nullable=False),
+        sa.PrimaryKeyConstraint('id'),
+        sa.ForeignKeyConstraint(['task_id', 'dag_id', 'execution_date'],
+                                ['task_instance.task_id', 'task_instance.dag_id','task_instance.execution_date'],
+                                name='task_reschedule_dag_task_date_fkey')
+    )
+    op.create_index(
+        INDEX_NAME,
+        TABLE_NAME,
+        ['dag_id', 'task_id', 'execution_date'],
+        unique=False
+    )
+
+
+def downgrade():
+    op.drop_index(INDEX_NAME, table_name=TABLE_NAME)
+    op.drop_table(TABLE_NAME)
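
Deployments picking up this change need the new table before sensors can
use reschedule mode. A hedged sketch of applying pending migrations from
Python; the usual route is running the airflow upgradedb CLI command:

    from airflow.utils.db import upgradedb

    # Runs Alembic against the configured sql_alchemy_conn and creates
    # the task_reschedule table (revision 0a2a5b66e19d) if missing.
    upgradedb()
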
diff --git a/airflow/models.py b/airflow/models.py
index d703810a77..c6787c693c 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -53,11 +53,11 @@
 import uuid
 from datetime import datetime
 from urllib.parse import urlparse, quote, parse_qsl
-
 from sqlalchemy import (
-    Column, Integer, String, DateTime, Text, Boolean, ForeignKey, PickleType,
-    Index, Float, LargeBinary, UniqueConstraint)
-from sqlalchemy import func, or_, and_, true as sqltrue
+    Boolean, Column, DateTime, Float, ForeignKey, ForeignKeyConstraint, Index,
+    Integer, LargeBinary, PickleType, String, Text, UniqueConstraint,
+    and_, asc, func, or_, true as sqltrue
+)
 from sqlalchemy.ext.declarative import declarative_base, declared_attr
 from sqlalchemy.orm import reconstructor, relationship, synonym
 
@@ -70,7 +70,8 @@
 from airflow.executors import GetDefaultExecutor, LocalExecutor
 from airflow import configuration
 from airflow.exceptions import (
-    AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout
+    AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout,
+    AirflowRescheduleException
 )
 from airflow.dag.base_dag import BaseDag, BaseDagBag
 from airflow.lineage import apply_lineage, prepare_lineage
@@ -1673,6 +1674,10 @@ def signal_handler(signum, frame):
         except AirflowSkipException:
             self.refresh_from_db(lock_for_update=True)
             self.state = State.SKIPPED
+        except AirflowRescheduleException as reschedule_exception:
+            self.refresh_from_db()
+            self._handle_reschedule(reschedule_exception, test_mode, context)
+            return
         except AirflowException as e:
             self.refresh_from_db()
             # for case when task is marked as success/failed externally
@@ -1744,6 +1749,32 @@ def dry_run(self):
         self.render_templates()
         task_copy.dry_run()
 
+    @provide_session
+    def _handle_reschedule(self, reschedule_exception, test_mode=False, context=None,
+                           session=None):
+        # Don't record reschedule request in test mode
+        if test_mode:
+            return
+
+        self.end_date = timezone.utcnow()
+        self.set_duration()
+
+        # Log reschedule request
+        session.add(TaskReschedule(self.task, self.execution_date, self._try_number,
+                    self.start_date, self.end_date,
+                    reschedule_exception.reschedule_date))
+
+        # set state
+        self.state = State.NONE
+
+        # Decrement try_number so subsequent runs will use the same try number and write
+        # to the same log file.
+        self._try_number -= 1
+
+        session.merge(self)
+        session.commit()
+        self.log.info('Rescheduling task, marking task as NONE')
+
     @provide_session
     def handle_failure(self, error, test_mode=False, context=None, session=None):
         self.log.exception(error)
@@ -2101,6 +2132,66 @@ def __init__(self, task, execution_date, start_date, end_date):
             self.duration = None
 
 
+class TaskReschedule(Base):
+    """
+    TaskReschedule tracks rescheduled task instances.
+    """
+
+    __tablename__ = "task_reschedule"
+
+    id = Column(Integer, primary_key=True)
+    task_id = Column(String(ID_LEN), nullable=False)
+    dag_id = Column(String(ID_LEN), nullable=False)
+    execution_date = Column(UtcDateTime, nullable=False)
+    try_number = Column(Integer, nullable=False)
+    start_date = Column(UtcDateTime, nullable=False)
+    end_date = Column(UtcDateTime, nullable=False)
+    duration = Column(Integer, nullable=False)
+    reschedule_date = Column(UtcDateTime, nullable=False)
+
+    __table_args__ = (
+        Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date,
+              unique=False),
+        ForeignKeyConstraint([task_id, dag_id, execution_date],
+                             [TaskInstance.task_id, TaskInstance.dag_id,
+                              TaskInstance.execution_date],
+                             name='task_reschedule_dag_task_date_fkey')
+    )
+
+    def __init__(self, task, execution_date, try_number, start_date, end_date,
+                 reschedule_date):
+        self.dag_id = task.dag_id
+        self.task_id = task.task_id
+        self.execution_date = execution_date
+        self.try_number = try_number
+        self.start_date = start_date
+        self.end_date = end_date
+        self.reschedule_date = reschedule_date
+        self.duration = (self.end_date - self.start_date).total_seconds()
+
+    @staticmethod
+    @provide_session
+    def find_for_task_instance(task_instance, session):
+        """
+        Returns all task reschedules for the task instance and try number,
+        in ascending order.
+
+        :param task_instance: the task instance to find task reschedules for
+        :type task_instance: TaskInstance
+        """
+        TR = TaskReschedule
+        return (
+            session
+            .query(TR)
+            .filter(TR.dag_id == task_instance.dag_id,
+                    TR.task_id == task_instance.task_id,
+                    TR.execution_date == task_instance.execution_date,
+                    TR.try_number == task_instance.try_number)
+            .order_by(asc(TR.id))
+            .all()
+        )
+
+
 class Log(Base):
     """
     Used to actively log events to the database
@@ -5066,12 +5157,13 @@ def update_state(self, session=None):
             no_dependencies_met = True
             for ut in unfinished_tasks:
                 # We need to flag upstream and check for changes because upstream
-                # failures can result in deadlock false positives
+                # failures/re-schedules can result in deadlock false positives
                 old_state = ut.state
                 deps_met = ut.are_dependencies_met(
                     dep_context=DepContext(
                         flag_upstream_failed=True,
-                        ignore_in_retry_period=True),
+                        ignore_in_retry_period=True,
+                        ignore_in_reschedule_period=True),
                     session=session)
                 if deps_met or old_state != ut.current_state(session=session):
                     no_dependencies_met = False
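
As a usage illustration (not part of the PR), the new static helper can
be called with just a task instance because @provide_session opens a
session when none is passed. It returns the reschedule requests recorded
for the instance's current try number, oldest first:

    from airflow.models import TaskReschedule

    def print_reschedule_history(ti):
        # ti is an airflow.models.TaskInstance
        for tr in TaskReschedule.find_for_task_instance(ti):
            print("try %d: ran %ds, rescheduled for %s" % (
                tr.try_number, tr.duration, tr.reschedule_date.isoformat()))
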
diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py
index 74b0e0fe1c..1dc59dd230 100644
--- a/airflow/sensors/base_sensor_operator.py
+++ b/airflow/sensors/base_sensor_operator.py
@@ -19,20 +19,22 @@
 
 
 from time import sleep
+from datetime import timedelta
 
 from airflow.exceptions import AirflowException, AirflowSensorTimeout, \
-    AirflowSkipException
-from airflow.models import BaseOperator, SkipMixin
+    AirflowSkipException, AirflowRescheduleException
+from airflow.models import BaseOperator, SkipMixin, TaskReschedule
 from airflow.utils import timezone
 from airflow.utils.decorators import apply_defaults
+from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 
 
 class BaseSensorOperator(BaseOperator, SkipMixin):
     """
-    Sensor operators are derived from this class an inherit these attributes.
+    Sensor operators are derived from this class and inherit these attributes.
 
     Sensor operators keep executing at a time interval and succeed when
-        a criteria is met and fail if and when they time out.
+    a criterion is met and fail if and when they time out.
 
     :param soft_fail: Set to true to mark the task as SKIPPED on failure
     :type soft_fail: bool
@@ -41,20 +43,42 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
     :type poke_interval: int
     :param timeout: Time, in seconds before the task times out and fails.
     :type timeout: int
+    :param mode: How the sensor operates.
+        Options are: ``{ poke | reschedule }``, default is ``poke``.
+        When set to ``poke`` the sensor takes up a worker slot for its
+        whole execution time and sleeps between pokes. Use this mode if the
+        expected runtime of the sensor is short or if a short poke interval
+        is required.
+        When set to ``reschedule`` the sensor task frees the worker slot when
+        the criterion is not yet met and it is rescheduled at a later time.
+        Use this mode if the time until the criterion is met is expected to
+        be long. The poke interval should be more than one minute to prevent
+        too much load on the scheduler.
+    :type mode: str
     """
     ui_color = '#e6f1f2'
+    valid_modes = ['poke', 'reschedule']
 
     @apply_defaults
     def __init__(self,
                  poke_interval=60,
                  timeout=60 * 60 * 24 * 7,
                  soft_fail=False,
+                 mode='poke',
                  *args,
                  **kwargs):
         super(BaseSensorOperator, self).__init__(*args, **kwargs)
         self.poke_interval = poke_interval
         self.soft_fail = soft_fail
         self.timeout = timeout
+        if mode not in self.valid_modes:
+            raise AirflowException(
+                "The mode must be one of {valid_modes},"
+                "'{d}.{t}'; received '{m}'."
+                .format(valid_modes=self.valid_modes,
+                        d=self.dag.dag_id if self.dag else "",
+                        t=self.task_id, m=mode))
+        self.mode = mode
 
     def poke(self, context):
         """
@@ -65,6 +89,11 @@ def poke(self, context):
 
     def execute(self, context):
         started_at = timezone.utcnow()
+        if self.reschedule:
+            # If reschedule, use the first start date of the current try
+            task_reschedules = TaskReschedule.find_for_task_instance(context['ti'])
+            if task_reschedules:
+                started_at = task_reschedules[0].start_date
         while not self.poke(context):
             if (timezone.utcnow() - started_at).total_seconds() > self.timeout:
                 # If sensor is in soft fail mode but will be retried then
@@ -75,7 +104,12 @@ def execute(self, context):
                     raise AirflowSkipException('Snap. Time is OUT.')
                 else:
                     raise AirflowSensorTimeout('Snap. Time is OUT.')
-            sleep(self.poke_interval)
+            if self.reschedule:
+                reschedule_date = timezone.utcnow() + timedelta(
+                    seconds=self.poke_interval)
+                raise AirflowRescheduleException(reschedule_date)
+            else:
+                sleep(self.poke_interval)
         self.log.info("Success criteria met. Exiting.")
 
     def _do_skip_downstream_tasks(self, context):
@@ -83,3 +117,15 @@ def _do_skip_downstream_tasks(self, context):
         self.log.debug("Downstream task_ids %s", downstream_tasks)
         if downstream_tasks:
             self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks)
+
+    @property
+    def reschedule(self):
+        return self.mode == 'reschedule'
+
+    @property
+    def deps(self):
+        """
+        Adds one additional dependency for all sensor operators that
+        checks if a sensor task instance can be rescheduled.
+        """
+        return BaseOperator.deps.fget(self) | {ReadyToRescheduleDep()}
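
To show the new parameter end to end, a minimal DAG sketch; the sensor
subclass and the marker file are illustrative only:

    import os
    from datetime import datetime

    from airflow import DAG
    from airflow.sensors.base_sensor_operator import BaseSensorOperator

    class MarkerFileSensor(BaseSensorOperator):
        # Hypothetical sensor: succeeds once the marker file exists.
        def poke(self, context):
            return os.path.exists('/tmp/marker')

    dag = DAG('example_reschedule', start_date=datetime(2018, 1, 1),
              schedule_interval=None)

    wait = MarkerFileSensor(
        task_id='wait_for_marker',
        poke_interval=300,      # 5 min; keep > 1 min in reschedule mode
        timeout=60 * 60 * 6,    # fail after 6 hours without success
        mode='reschedule',      # free the worker slot between pokes
        dag=dag)
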
diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py
index a0f30831e5..60d4118d84 100644
--- a/airflow/ti_deps/dep_context.py
+++ b/airflow/ti_deps/dep_context.py
@@ -58,6 +58,8 @@ class DepContext(object):
     :type ignore_depends_on_past: bool
     :param ignore_in_retry_period: Ignore the retry period for task instances
     :type ignore_in_retry_period: bool
+    :param ignore_in_reschedule_period: Ignore the reschedule period for task instances
+    :type ignore_in_reschedule_period: bool
     :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past and
         trigger rule
     :type ignore_task_deps: bool
@@ -71,6 +73,7 @@ def __init__(
             ignore_all_deps=False,
             ignore_depends_on_past=False,
             ignore_in_retry_period=False,
+            ignore_in_reschedule_period=False,
             ignore_task_deps=False,
             ignore_ti_state=False):
         self.deps = deps or set()
@@ -78,6 +81,7 @@ def __init__(
         self.ignore_all_deps = ignore_all_deps
         self.ignore_depends_on_past = ignore_depends_on_past
         self.ignore_in_retry_period = ignore_in_retry_period
+        self.ignore_in_reschedule_period = ignore_in_reschedule_period
         self.ignore_task_deps = ignore_task_deps
         self.ignore_ti_state = ignore_ti_state
 
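The scheduler-side use of the new flag mirrors ignore_in_retry_period
(see the update_state hunk in models.py above). A hedged sketch:

    from airflow.ti_deps.dep_context import DepContext

    # As in DagRun.update_state: tasks waiting out a reschedule period
    # should not be counted as deadlocked, so their wait is ignored
    # when probing for met dependencies.
    dep_context = DepContext(
        flag_upstream_failed=True,
        ignore_in_retry_period=True,
        ignore_in_reschedule_period=True)
    # deps_met = ut.are_dependencies_met(dep_context=dep_context,
    #                                    session=session)
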
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py
new file mode 100644
index 0000000000..e0f5f8fdfe
--- /dev/null
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+from airflow.utils import timezone
+from airflow.utils.db import provide_session
+from airflow.utils.state import State
+
+
+class ReadyToRescheduleDep(BaseTIDep):
+    NAME = "Ready To Reschedule"
+    IGNOREABLE = True
+    IS_TASK_DEP = True
+
+    @provide_session
+    def _get_dep_statuses(self, ti, session, dep_context):
+        """
+        Determines whether a task is ready to be rescheduled. Only tasks in
+        NONE state with at least one row in the task_reschedule table are
+        handled by this dependency class; otherwise this dependency is
+        considered passed. This dependency fails if the latest reschedule
+        request's reschedule date is still in the future.
+        """
+        if dep_context.ignore_in_reschedule_period:
+            yield self._passing_status(
+                reason="The context specified that being in a reschedule period was "
+                       "permitted.")
+            return
+
+        if ti.state != State.NONE:
+            yield self._passing_status(
+                reason="The task instance is not in NONE state.")
+            return
+
+        # Lazy import to avoid circular dependency
+        from airflow.models import TaskReschedule
+        task_reschedules = TaskReschedule.find_for_task_instance(task_instance=ti)
+        if not task_reschedules:
+            yield self._passing_status(
+                reason="There is no reschedule request for this task instance.")
+            return
+
+        now = timezone.utcnow()
+        next_reschedule_date = task_reschedules[-1].reschedule_date
+        if now >= next_reschedule_date:
+            yield self._passing_status(
+                reason="Task instance id ready for reschedule.")
+            return
+
+        yield self._failing_status(
+            reason="Task is not ready for reschedule yet but will be rescheduled "
+                   "automatically. Current date is {0} and task will be rescheduled "
+                   "at {1}.".format(now.isoformat(), next_reschedule_date.isoformat()))
diff --git a/airflow/www/static/gantt-chart-d3v2.js b/airflow/www/static/gantt-chart-d3v2.js
index d21311a1c5..245a0147e9 100644
--- a/airflow/www/static/gantt-chart-d3v2.js
+++ b/airflow/www/static/gantt-chart-d3v2.js
@@ -129,7 +129,7 @@ d3.gantt = function() {
       call_modal(d.taskName, d.executionDate);
     })
     .attr("class", function(d){
-      if(taskStatus[d.status] == null){ return "bar";}
+      if(taskStatus[d.status] == null){ return "null";}
       return taskStatus[d.status];
     })
     .attr("y", 0)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index aa2530e458..6d5794a89b 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1927,21 +1927,48 @@ def gantt(self, session=None):
                     TF.execution_date == ti.execution_date)
             .all()
         ) for ti in tis]))
-        tis_with_fails = sorted(tis + ti_fails, key=lambda ti: ti.start_date)
+        TR = models.TaskReschedule
+        ti_reschedules = list(itertools.chain(*[(
+            session
+            .query(TR)
+            .filter(TR.dag_id == ti.dag_id,
+                    TR.task_id == ti.task_id,
+                    TR.execution_date == ti.execution_date)
+            .all()
+        ) for ti in tis]))
+        # determine bars to show in the gantt chart
+        # all reschedules of one attempt are combined into one bar
+        gantt_bar_items = []
+        for task_id, items in itertools.groupby(
+                sorted(tis + ti_fails + ti_reschedules, key=lambda ti: ti.task_id),
+                key=lambda ti: ti.task_id):
+            start_date = None
+            for i in sorted(items, key=lambda ti: ti.start_date):
+                start_date = start_date or i.start_date
+                end_date = i.end_date or timezone.utcnow()
+                if type(i) == models.TaskInstance:
+                    gantt_bar_items.append((task_id, start_date, end_date, i.state))
+                    start_date = None
+                elif type(i) == TF and (len(gantt_bar_items) == 0 or
+                                        end_date != gantt_bar_items[-1][2]):
+                    gantt_bar_items.append((task_id, start_date, end_date, State.FAILED))
+                    start_date = None
 
         tasks = []
-        for ti in tis_with_fails:
-            end_date = ti.end_date if ti.end_date else timezone.utcnow()
-            state = ti.state if type(ti) == models.TaskInstance else State.FAILED
+        for gantt_bar_item in gantt_bar_items:
+            task_id = gantt_bar_item[0]
+            start_date = gantt_bar_item[1]
+            end_date = gantt_bar_item[2]
+            state = gantt_bar_item[3]
             tasks.append({
-                'startDate': wwwutils.epoch(ti.start_date),
+                'startDate': wwwutils.epoch(start_date),
                 'endDate': wwwutils.epoch(end_date),
-                'isoStart': ti.start_date.isoformat()[:-4],
+                'isoStart': start_date.isoformat()[:-4],
                 'isoEnd': end_date.isoformat()[:-4],
-                'taskName': ti.task_id,
-                'duration': "{}".format(end_date - ti.start_date)[:-4],
+                'taskName': task_id,
+                'duration': "{}".format(end_date - start_date)[:-4],
                 'status': state,
-                'executionDate': ti.execution_date.isoformat(),
+                'executionDate': dttm.isoformat(),
             })
         states = {task['status']: task['status'] for task in tasks}
         data = {
diff --git a/airflow/www_rbac/static/js/gantt-chart-d3v2.js b/airflow/www_rbac/static/js/gantt-chart-d3v2.js
index d21311a1c5..245a0147e9 100644
--- a/airflow/www_rbac/static/js/gantt-chart-d3v2.js
+++ b/airflow/www_rbac/static/js/gantt-chart-d3v2.js
@@ -129,7 +129,7 @@ d3.gantt = function() {
       call_modal(d.taskName, d.executionDate);
     })
     .attr("class", function(d){
-      if(taskStatus[d.status] == null){ return "bar";}
+      if(taskStatus[d.status] == null){ return "null";}
       return taskStatus[d.status];
     })
     .attr("y", 0)
diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py
index 3dc3400968..d1bed8e2a1 100644
--- a/airflow/www_rbac/views.py
+++ b/airflow/www_rbac/views.py
@@ -1677,21 +1677,49 @@ def gantt(self, session=None):
                     TF.execution_date == ti.execution_date)
             .all()
         ) for ti in tis]))
-        tis_with_fails = sorted(tis + ti_fails, key=lambda ti: ti.start_date)
+        TR = models.TaskReschedule
+        ti_reschedules = list(itertools.chain(*[(
+            session
+            .query(TR)
+            .filter(TR.dag_id == ti.dag_id,
+                    TR.task_id == ti.task_id,
+                    TR.execution_date == ti.execution_date)
+            .all()
+        ) for ti in tis]))
+
+        # determine bars to show in the gantt chart
+        # all reschedules of one attempt are combined into one bar
+        gantt_bar_items = []
+        for task_id, items in itertools.groupby(
+                sorted(tis + ti_fails + ti_reschedules, key=lambda ti: ti.task_id),
+                key=lambda ti: ti.task_id):
+            start_date = None
+            for i in sorted(items, key=lambda ti: ti.start_date):
+                start_date = start_date or i.start_date
+                end_date = i.end_date or timezone.utcnow()
+                if type(i) == models.TaskInstance:
+                    gantt_bar_items.append((task_id, start_date, end_date, i.state))
+                    start_date = None
+                elif type(i) == TF and (len(gantt_bar_items) == 0 or
+                                        end_date != gantt_bar_items[-1][2]):
+                    gantt_bar_items.append((task_id, start_date, end_date, State.FAILED))
+                    start_date = None
 
         tasks = []
-        for ti in tis_with_fails:
-            end_date = ti.end_date if ti.end_date else timezone.utcnow()
-            state = ti.state if type(ti) == models.TaskInstance else State.FAILED
+        for gantt_bar_item in gantt_bar_items:
+            task_id = gantt_bar_item[0]
+            start_date = gantt_bar_item[1]
+            end_date = gantt_bar_item[2]
+            state = gantt_bar_item[3]
             tasks.append({
-                'startDate': wwwutils.epoch(ti.start_date),
+                'startDate': wwwutils.epoch(start_date),
                 'endDate': wwwutils.epoch(end_date),
-                'isoStart': ti.start_date.isoformat()[:-4],
+                'isoStart': start_date.isoformat()[:-4],
                 'isoEnd': end_date.isoformat()[:-4],
-                'taskName': ti.task_id,
-                'duration': "{}".format(end_date - ti.start_date)[:-4],
+                'taskName': task_id,
+                'duration': "{}".format(end_date - start_date)[:-4],
                 'status': state,
-                'executionDate': ti.execution_date.isoformat(),
+                'executionDate': dttm.isoformat(),
             })
         states = {task['status']: task['status'] for task in tasks}
         data = {
diff --git a/tests/sensors/test_base_sensor.py b/tests/sensors/test_base_sensor.py
index adb7a5d1e3..353f4447b1 100644
--- a/tests/sensors/test_base_sensor.py
+++ b/tests/sensors/test_base_sensor.py
@@ -18,17 +18,21 @@
 # under the License.
 
 import unittest
+from mock import Mock
 
 from airflow import DAG, configuration, settings
-from airflow.exceptions import AirflowSensorTimeout
-from airflow.models import DagRun, TaskInstance
+from airflow.exceptions import (AirflowSensorTimeout, AirflowException,
+                                AirflowRescheduleException)
+from airflow.models import DagRun, TaskInstance, TaskReschedule
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.utils import timezone
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from datetime import timedelta
 from time import sleep
+from freezegun import freeze_time
 
 configuration.load_test_config()
 
@@ -57,6 +61,7 @@ def setUp(self):
         self.dag = DAG(TEST_DAG_ID, default_args=args)
 
         session = settings.Session()
+        session.query(TaskReschedule).delete()
         session.query(DagRun).delete()
         session.query(TaskInstance).delete()
         session.commit()
@@ -158,3 +163,297 @@ def test_soft_fail_with_retries(self):
         self.assertEquals(len(tis), 2)
         for ti in tis:
             self.assertEquals(ti.state, State.SKIPPED)
+
+    def test_ok_with_reschedule(self):
+        sensor = self._make_sensor(
+            return_value=None,
+            poke_interval=10,
+            timeout=25,
+            mode='reschedule')
+        sensor.poke = Mock(side_effect=[False, False, True])
+        dr = self._make_dag_run()
+
+        # first poke returns False and task is re-scheduled
+        date1 = timezone.utcnow()
+        with freeze_time(date1):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                # verify task is re-scheduled, i.e. state set to NONE
+                self.assertEquals(ti.state, State.NONE)
+                # verify one row in task_reschedule table
+                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                self.assertEquals(len(task_reschedules), 1)
+                self.assertEquals(task_reschedules[0].start_date, date1)
+                self.assertEquals(task_reschedules[0].reschedule_date,
+                                  date1 + timedelta(seconds=sensor.poke_interval))
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # second poke returns False and task is re-scheduled
+        date2 = date1 + timedelta(seconds=sensor.poke_interval)
+        with freeze_time(date2):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                # verify task is re-scheduled, i.e. state set to NONE
+                self.assertEquals(ti.state, State.NONE)
+                # verify two rows in task_reschedule table
+                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                self.assertEquals(len(task_reschedules), 2)
+                self.assertEquals(task_reschedules[1].start_date, date2)
+                self.assertEquals(task_reschedules[1].reschedule_date,
+                                  date2 + timedelta(seconds=sensor.poke_interval))
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # third poke returns True and task succeeds
+        date3 = date2 + timedelta(seconds=sensor.poke_interval)
+        with freeze_time(date3):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.SUCCESS)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+    def test_fail_with_reschedule(self):
+        sensor = self._make_sensor(
+            return_value=False,
+            poke_interval=10,
+            timeout=5,
+            mode='reschedule')
+        dr = self._make_dag_run()
+
+        # first poke returns False and task is re-scheduled
+        date1 = timezone.utcnow()
+        with freeze_time(date1):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.NONE)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # second poke returns False, timeout occurs
+        date2 = date1 + timedelta(seconds=sensor.poke_interval)
+        with freeze_time(date2):
+            with self.assertRaises(AirflowSensorTimeout):
+                self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.FAILED)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+    def test_soft_fail_with_reschedule(self):
+        sensor = self._make_sensor(
+            return_value=False,
+            poke_interval=10,
+            timeout=5,
+            soft_fail=True,
+            mode='reschedule')
+        dr = self._make_dag_run()
+
+        # first poke returns False and task is re-scheduled
+        date1 = timezone.utcnow()
+        with freeze_time(date1):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.NONE)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # second poke returns False, timeout occurs
+        date2 = date1 + timedelta(seconds=sensor.poke_interval)
+        with freeze_time(date2):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            self.assertEquals(ti.state, State.SKIPPED)
+
+    def test_ok_with_reschedule_and_retry(self):
+        sensor = self._make_sensor(
+            return_value=None,
+            poke_interval=10,
+            timeout=5,
+            retries=1,
+            retry_delay=timedelta(seconds=10),
+            mode='reschedule')
+        sensor.poke = Mock(side_effect=[False, False, False, True])
+        dr = self._make_dag_run()
+
+        # first poke returns False and task is re-scheduled
+        date1 = timezone.utcnow()
+        with freeze_time(date1):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.NONE)
+                # verify one row in task_reschedule table
+                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                self.assertEquals(len(task_reschedules), 1)
+                self.assertEquals(task_reschedules[0].start_date, date1)
+                self.assertEquals(task_reschedules[0].reschedule_date,
+                                  date1 + timedelta(seconds=sensor.poke_interval))
+                self.assertEqual(task_reschedules[0].try_number, 1)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # second poke fails and task instance is marked up to retry
+        date2 = date1 + timedelta(seconds=sensor.poke_interval)
+        with freeze_time(date2):
+            with self.assertRaises(AirflowSensorTimeout):
+                self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.UP_FOR_RETRY)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # third poke returns False and task is rescheduled again
+        date3 = date2 + timedelta(seconds=sensor.poke_interval) + sensor.retry_delay
+        with freeze_time(date3):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.NONE)
+                # verify one row in task_reschedule table
+                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                self.assertEquals(len(task_reschedules), 1)
+                self.assertEquals(task_reschedules[0].start_date, date3)
+                self.assertEquals(task_reschedules[0].reschedule_date,
+                                  date3 + timedelta(seconds=sensor.poke_interval))
+                self.assertEqual(task_reschedules[0].try_number, 2)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # fourth poke returns True and task succeeds
+        date4 = date3 + timedelta(seconds=sensor.poke_interval)
+        with freeze_time(date4):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.SUCCESS)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+    def test_should_include_ready_to_reschedule_dep(self):
+        sensor = self._make_sensor(True)
+        deps = sensor.deps
+        self.assertTrue(ReadyToRescheduleDep() in deps)
+
+    def test_invalid_mode(self):
+        with self.assertRaises(AirflowException):
+            self._make_sensor(
+                return_value=True,
+                mode='foo')
+
+    def test_ok_with_custom_reschedule_exception(self):
+        sensor = self._make_sensor(
+            return_value=None,
+            mode='reschedule')
+        date1 = timezone.utcnow()
+        date2 = date1 + timedelta(seconds=60)
+        date3 = date1 + timedelta(seconds=120)
+        sensor.poke = Mock(side_effect=[
+            AirflowRescheduleException(date2),
+            AirflowRescheduleException(date3),
+            True,
+        ])
+        dr = self._make_dag_run()
+
+        # first poke returns False and task is re-scheduled
+        with freeze_time(date1):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                # verify task is re-scheduled, i.e. state set to NONE
+                self.assertEquals(ti.state, State.NONE)
+                # verify one row in task_reschedule table
+                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                self.assertEquals(len(task_reschedules), 1)
+                self.assertEquals(task_reschedules[0].start_date, date1)
+                self.assertEquals(task_reschedules[0].reschedule_date, date2)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # second poke returns False and task is re-scheduled
+        with freeze_time(date2):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                # verify task is re-scheduled, i.e. state set to NONE
+                self.assertEquals(ti.state, State.NONE)
+                # verify two rows in task_reschedule table
+                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                self.assertEquals(len(task_reschedules), 2)
+                self.assertEquals(task_reschedules[1].start_date, date2)
+                self.assertEquals(task_reschedules[1].reschedule_date, date3)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        # third poke returns True and task succeeds
+        with freeze_time(date3):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.SUCCESS)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+    def test_reschedule_with_test_mode(self):
+        sensor = self._make_sensor(
+            return_value=None,
+            poke_interval=10,
+            timeout=25,
+            mode='reschedule')
+        sensor.poke = Mock(side_effect=[False])
+        dr = self._make_dag_run()
+
+        # poke returns False and AirflowRescheduleException is raised
+        date1 = timezone.utcnow()
+        with freeze_time(date1):
+            for dt in self.dag.date_range(DEFAULT_DATE, end_date=DEFAULT_DATE):
+                TaskInstance(sensor, dt).run(
+                    ignore_ti_state=True,
+                    test_mode=True)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                # in test mode state is not modified
+                self.assertEquals(ti.state, State.NONE)
+                # in test mode no reschedule request is recorded
+                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                self.assertEquals(len(task_reschedules), 0)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
new file mode 100644
index 0000000000..898850f8b7
--- /dev/null
+++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from datetime import timedelta
+from mock import Mock, patch
+
+from airflow.models import TaskInstance, DAG, TaskReschedule
+from airflow.ti_deps.dep_context import DepContext
+from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
+from airflow.utils.state import State
+from airflow.utils.timezone import utcnow
+
+
+class NotInReschedulePeriodDepTest(unittest.TestCase):
+
+    def _get_task_instance(self, state):
+        dag = DAG('test_dag')
+        task = Mock(dag=dag)
+        ti = TaskInstance(task=task, state=state, execution_date=None)
+        return ti
+
+    def _get_task_reschedule(self, reschedule_date):
+        task = Mock(dag_id='test_dag', task_id='test_task')
+        tr = TaskReschedule(task=task, execution_date=None, try_number=None,
+                            start_date=reschedule_date, end_date=reschedule_date,
+                            reschedule_date=reschedule_date)
+        return tr
+
+    def test_should_pass_if_ignore_in_reschedule_period_is_set(self):
+        ti = self._get_task_instance(State.NONE)
+        dep_context = DepContext(ignore_in_reschedule_period=True)
+        self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context))
+
+    def test_should_pass_if_not_in_none_state(self):
+        ti = self._get_task_instance(State.UP_FOR_RETRY)
+        self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
+
+    @patch('airflow.models.TaskReschedule.find_for_task_instance', return_value=[])
+    def test_should_pass_if_no_reschedule_record_exists(self, find_for_task_instance):
+        ti = self._get_task_instance(State.NONE)
+        self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
+
+    @patch('airflow.models.TaskReschedule.find_for_task_instance')
+    def test_should_pass_after_reschedule_date_one(self, find_for_task_instance):
+        find_for_task_instance.return_value = [
+            self._get_task_reschedule(utcnow() - timedelta(minutes=1)),
+        ]
+        ti = self._get_task_instance(State.NONE)
+        self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
+
+    @patch('airflow.models.TaskReschedule.find_for_task_instance')
+    def test_should_pass_after_reschedule_date_multiple(self, find_for_task_instance):
+        find_for_task_instance.return_value = [
+            self._get_task_reschedule(utcnow() - timedelta(minutes=21)),
+            self._get_task_reschedule(utcnow() - timedelta(minutes=11)),
+            self._get_task_reschedule(utcnow() - timedelta(minutes=1)),
+        ]
+        ti = self._get_task_instance(State.NONE)
+        self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
+
+    @patch('airflow.models.TaskReschedule.find_for_task_instance')
+    def test_should_fail_before_reschedule_date_one(self, find_for_task_instance):
+        find_for_task_instance.return_value = [
+            self._get_task_reschedule(utcnow() + timedelta(minutes=1)),
+        ]
+        ti = self._get_task_instance(State.NONE)
+        self.assertFalse(ReadyToRescheduleDep().is_met(ti=ti))
+
+    @patch('airflow.models.TaskReschedule.find_for_task_instance')
+    def test_should_fail_before_reschedule_date_multiple(self, find_for_task_instance):
+        find_for_task_instance.return_value = [
+            self._get_task_reschedule(utcnow() - timedelta(minutes=19)),
+            self._get_task_reschedule(utcnow() - timedelta(minutes=9)),
+            self._get_task_reschedule(utcnow() + timedelta(minutes=1)),
+        ]
+        ti = self._get_task_instance(State.NONE)
+        self.assertFalse(ReadyToRescheduleDep().is_met(ti=ti))

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services